From 7d3f8a09e4700c5378f97e83640ea20f83c23c4b Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 7 Apr 2022 04:03:57 +0000 Subject: [PATCH 01/44] [TUNER] Add attribute control for splitK --- .../Dialect/Flow/Transforms/SplitReduction.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp index 831d95b25850d..fec4ce04f5b13 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp @@ -24,6 +24,8 @@ namespace iree_compiler { namespace IREE { namespace Flow { +static const char kSplitKAttr[] = "iree_flow_split_k"; + // TODO(thomasraoux): Move to attributes. static llvm::cl::opt splitReductionRatio("iree-flow-split-matmul-reduction", @@ -83,19 +85,18 @@ struct SplitReductionPass : public SplitReductionBase { } void runOnOperation() override { - if (splitReductionRatio.getValue() <= 1 && - topkSplitReductionRatio.empty()) { - return; - } - RewritePatternSet patterns(&getContext()); patterns.add( &getContext(), [&](linalg::LinalgOp op) -> linalg::SplitReductionOptions { + int64_t ratio = splitReductionRatio; + if (auto attr = op->getAttrOfType(kSplitKAttr)) + ratio = attr.getInt(); + if (ratio <= 1) return {int64_t(0), 0, /*innerParallel=*/false}; // For matmul make the new parallel dimension first so that it looks // like a batch_matmul and can follow the same codegen. if (isa(op)) - return {int64_t(splitReductionRatio), 0, /*innerParallel=*/false}; + return {ratio, 0, /*innerParallel=*/false}; // Currently disable spliting reduction for non-matmul op. This will // get enabled after once tests are ready. 
return {int64_t(0), 0, /*innerParallel=*/false}; From 68f1c2109bce80ee14637b9ed9be5a4083a338a7 Mon Sep 17 00:00:00 2001 From: Nirvedh Date: Fri, 8 Apr 2022 21:05:45 +0000 Subject: [PATCH 02/44] [TUNER] Add attribute control for swizzle --- .../iree/compiler/Codegen/Common/GPU/WorkGroupSwizzle.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/WorkGroupSwizzle.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/WorkGroupSwizzle.cpp index 18bfe9b489a08..02e58e13807ac 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/WorkGroupSwizzle.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/WorkGroupSwizzle.cpp @@ -12,6 +12,8 @@ namespace mlir { namespace iree_compiler { +static const char kSwizzleAttr[] = "iree_swizzle"; + /// This function implements the following swizzling logic /// void getTiledId2(unsigned x, unsigned y, unsigned* tiledx, /// unsigned* tiledy) { @@ -110,6 +112,10 @@ struct WorkGroupSwizzlePass } void runOnOperation() override { func::FuncOp funcOp = getOperation(); + funcOp.walk([&](linalg::LinalgOp op) { + if (auto attr = op->getAttrOfType(kSwizzleAttr)) + swizzleLogTile = attr.getInt(); + }); (void)swizzleWorkgroupsInFunc(funcOp, swizzleLogTile); } From e0fbd64225750afa94254f64d5571409d3195aad Mon Sep 17 00:00:00 2001 From: yzhang93 Date: Tue, 11 Oct 2022 20:05:01 +0000 Subject: [PATCH 03/44] [TUNER] Allow split-k working on generic ops with reduction --- .../compiler/Dialect/Flow/Transforms/SplitReduction.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp index fec4ce04f5b13..494e4ecb08e19 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp @@ -97,6 +97,14 @@ struct SplitReductionPass : public SplitReductionBase { // like a 
batch_matmul and can follow the same codegen. if (isa(op)) return {ratio, 0, /*innerParallel=*/false}; + else if (isa(op)){ + SmallVector reductionDims; + op.getReductionDims(reductionDims); + if (reductionDims.empty()) + return {int64_t(0), 0, /*innerParallel=*/false}; + else + return {ratio, 0, /*innerParallel=*/false}; + } // Currently disable spliting reduction for non-matmul op. This will // get enabled after once tests are ready. return {int64_t(0), 0, /*innerParallel=*/false}; From 3d57c6161ea69aaf5158d2608a3f42f0ea9e7837 Mon Sep 17 00:00:00 2001 From: Eliasj42 <46754803+Eliasj42@users.noreply.github.com> Date: Sun, 5 Jun 2022 15:45:40 -0700 Subject: [PATCH 04/44] [CUDA] added function to sync cuda context to current thread (#8) Co-authored-by: Elias Joseph --- runtime/src/iree/hal/drivers/cuda/cuda_device.c | 8 ++++++++ runtime/src/iree/hal/drivers/cuda/cuda_device.h | 2 ++ 2 files changed, 10 insertions(+) diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c index 136d729506088..0f39601610180 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c @@ -82,6 +82,14 @@ static iree_hal_cuda_device_t* iree_hal_cuda_device_cast_unsafe( return (iree_hal_cuda_device_t*)base_value; } +iree_status_t iree_cuda_set_current_thread(iree_hal_device_t* device){ + iree_hal_cuda_device_t* cuda_device = iree_hal_cuda_device_cast(device); + CUDA_RETURN_IF_ERROR(cuda_device->context_wrapper.syms, + cuCtxSetCurrent(cuda_device->context_wrapper.cu_context), + "cuCtxSetCurrent"); + return iree_ok_status(); +} + IREE_API_EXPORT void iree_hal_cuda_device_params_initialize( iree_hal_cuda_device_params_t* out_params) { memset(out_params, 0, sizeof(*out_params)); diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.h b/runtime/src/iree/hal/drivers/cuda/cuda_device.h index 0cc08870c6a0b..2af0c77c68bd5 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_device.h +++ 
b/runtime/src/iree/hal/drivers/cuda/cuda_device.h @@ -41,6 +41,8 @@ CUcontext iree_hal_cuda_device_context(iree_hal_device_t* device); iree_hal_cuda_dynamic_symbols_t* iree_hal_cuda_device_dynamic_symbols( iree_hal_device_t* device); +iree_status_t iree_cuda_set_current_thread(iree_hal_device_t* device); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus From d122699e76e5657abca3128a6e3b21bd5b037a63 Mon Sep 17 00:00:00 2001 From: stanley Date: Mon, 31 Oct 2022 07:10:06 +0000 Subject: [PATCH 05/44] [vulkan] Combine broadcast+transfer_read on Vulkan and modify subspan --- .../Common/FlattenMemRefSubspanPass.cpp | 20 ++++++++ .../SPIRVTileAndVectorizeToCooperativeOps.cpp | 50 +++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp index 744b29304d894..e15916f7cd362 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp @@ -749,6 +749,25 @@ struct RemoveDynamicCastOp final : public OpRewritePattern { } }; +/// Removes memref.cast that turns dynamic shapes into static shapes. +struct RemoveStaticCastOp final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CastOp castOp, + PatternRewriter &rewriter) const override { + auto srcType = castOp.getSource().getType().cast(); + auto dstType = castOp.getType().cast(); + // Restrict to the cases we generate in this pass--1-D static shape to 1-D + // dynamic shape. 
+ if (srcType.getRank() == 1 && !srcType.hasStaticShape() && + dstType.getRank() == 1 && dstType.hasStaticShape()) { + rewriter.replaceOp(castOp, castOp.getSource()); + return success(); + } + return failure(); + } +}; + //===----------------------------------------------------------------------===// // Pass //===----------------------------------------------------------------------===// @@ -907,6 +926,7 @@ struct FlattenMemRefSubspanPass memref::AllocaOp::getCanonicalizationPatterns(cleanupPatterns, context); memref::SubViewOp::getCanonicalizationPatterns(cleanupPatterns, context); cleanupPatterns.add(context); + cleanupPatterns.add(context); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(cleanupPatterns)))) { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp index 3f3e52d0954d5..3aa60b9b3eabe 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp @@ -303,6 +303,40 @@ class CombineContractTranspose final } }; +// Merge broadcast op into the transfer read op. Broadcast are not supported on +// MMA types. 
+struct CombineTransferReadOpBroadcast final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::BroadcastOp op, + PatternRewriter &rewriter) const override { + auto transferReadOp = + op.getSource().getDefiningOp(); + if (!transferReadOp || transferReadOp.getMask() || + transferReadOp.hasOutOfBoundsDim()) { + return failure(); + } + int64_t rankDiff = + op.getResultVectorType().getRank() - transferReadOp.getVectorType().getRank(); + SmallVector exprs(rankDiff, rewriter.getAffineConstantExpr(0)); + ArrayRef originalExpr = + transferReadOp.getPermutationMap().getResults(); + exprs.append(originalExpr.begin(), originalExpr.end()); + AffineMap newMap = + AffineMap::get(transferReadOp.getPermutationMap().getNumDims(), + transferReadOp.getPermutationMap().getNumSymbols(), + exprs, op.getContext()); + ArrayAttr inBounds = rewriter.getBoolArrayAttr( + SmallVector(op.getResultVectorType().getRank(), true)); + rewriter.replaceOpWithNewOp( + op, op.getType(), transferReadOp.getSource(), + transferReadOp.getIndices(), newMap, transferReadOp.getPadding(), + transferReadOp.getMask(), inBounds); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Main pass //===----------------------------------------------------------------------===// @@ -436,6 +470,22 @@ class SPIRVVectorizeToCooperativeOpsPass final llvm::dbgs() << "\n\n"; }); + { + RewritePatternSet combineBroadcastPatterns(context); + combineBroadcastPatterns.insert( + funcOp.getContext()); + if (failed(applyPatternsAndFoldGreedily( + funcOp, std::move(combineBroadcastPatterns)))) { + return signalPassFailure(); + } + } + + LLVM_DEBUG({ + llvm::dbgs() << "--- After combining transfer_read and broadcast ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + { RewritePatternSet vectorUnrollPatterns(context); 
populateVectorUnrollPatterns(cooperativeOpSize, vectorUnrollPatterns); From d70807209d5a00d1ca571633345d5943e2583074 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 29 Sep 2022 15:44:59 -0400 Subject: [PATCH 06/44] [Flow] Pass to convert NCHW convolutions to NHWC The conversion pass is enabled with `--iree-flow-enable-conv-nchw-to-nhwc-transform` Includes partial support for propagating and cancelling transposes generated when converting from nchw to nhwc. The high level strategy for this pass is as follows: 1. Do the conversions for all conv_nchw_fchw ops (and pooling ops) and wrap the converted convolutions in transposes. Each transpose is tagged to indicate which direction the transpose should propagate through the graph. 2. Traverse the ops in the function in reverse to propagate transposes marked for upwards propagation to their parents. Ideally just before ops such as arith.constant or function arguments. 3. Propagate the transposes marked for downward propagation to its users, ideally to just before return. 4. Canonicalize out all adjacent cancelling transposes and generalize the remaining transposes to allow for fusing them with nearby ops. 
--- .../compiler/Preprocessing/Common/BUILD.bazel | 1 + .../Preprocessing/Common/CMakeLists.txt | 1 + .../Common/ConvertConvNchwToNhwc.cpp | 564 ++++++++++++++++++ .../compiler/Preprocessing/Common/Passes.h | 7 +- .../compiler/Preprocessing/Common/Passes.td | 7 + .../Preprocessing/Common/test/BUILD.bazel | 1 + .../Preprocessing/Common/test/CMakeLists.txt | 1 + .../Common/test/conv2d_nchw_to_nhwc.mlir | 40 ++ 8 files changed, 621 insertions(+), 1 deletion(-) create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/test/conv2d_nchw_to_nhwc.mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index c9acfa23b3e58..eb1f9fabdc0d4 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -31,6 +31,7 @@ iree_compiler_cc_library( name = "Transforms", srcs = [ "ConvertConv2DToImg2Col.cpp", + "ConvertConvNchwToNhwc.cpp", "MakeSingleDispatchForFunction.cpp", "PadLinalgOps.cpp", "PassDetail.h", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index fb7b26ff5dcf7..b6ca3b0d3d5bd 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -27,6 +27,7 @@ iree_cc_library( "Passes.h.inc" SRCS "ConvertConv2DToImg2Col.cpp" + "ConvertConvNchwToNhwc.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" "PassDetail.h" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp new file mode 100644 index 0000000000000..5704f18ed4665 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvNchwToNhwc.cpp @@ -0,0 
+1,564 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-flow-convert-conv-nchw-to-nhwc" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +using TransposeIndices = SmallVector; + +static const StringLiteral transposeEmptyMarker = "__nchw_to_nhwc_init__"; +static const StringLiteral transposePropagateUpMarker = "__nchw_to_nhwc_up__"; +static const StringLiteral transposePropagateDownMarker = + "__nchw_to_nhwc_down__"; + +static TransposeIndices invertIndices(TransposeIndices targetIndices) { + auto rank = targetIndices.size(); + TransposeIndices inverted(rank); + for (auto i : llvm::enumerate(targetIndices)) { + inverted[i.value()] = i.index(); + } + return inverted; +} + +static TransposeIndices getTransposeIndices(linalg::TransposeOp op) { + return llvm::to_vector(op.getPermutation()); +} + +static bool isStaticallyShaped(Value input) { + if (auto inputType = input.getType().dyn_cast()) + return inputType.hasStaticShape(); + return false; +} + +// Get the transpose indices if the given input comes from a transpose and is +// marked to propagate down. 
+static std::optional getIndicesFromInput(Value input) { + if (!isStaticallyShaped(input)) return std::nullopt; + auto parent = input.getDefiningOp(); + if (parent && parent->hasAttr(transposePropagateDownMarker)) + return getTransposeIndices(parent); + return std::nullopt; +} + +// Get the transpose indices if the given output is used by at least one +// transpose and that transpose is marked to propagate up. Additionally don't +// propagate if there are conflicting transposes. +static std::optional getIndicesFromOutput(Value output) { + if (!isStaticallyShaped(output)) return std::nullopt; + std::optional transposedOut; + if (llvm::all_of(output.getUses(), [&transposedOut](const OpOperand &use) { + auto owner = dyn_cast(use.getOwner()); + if (owner && owner->hasAttr(transposePropagateUpMarker)) { + if (transposedOut.has_value()) { + if (getTransposeIndices(transposedOut.value()) == + getTransposeIndices(owner)) + return true; + return false; + } + transposedOut = owner; + return true; + } + return false; + })) { + if (transposedOut.has_value()) + return getTransposeIndices(transposedOut.value()); + } + return std::nullopt; +} + +// Helper to shuffle vectors according to the transpose indices. +template +static SmallVector shuffleFromIndices(SmallVector unshuffled, + TransposeIndices targetIndices) { + auto rank = unshuffled.size(); + assert(targetIndices.size() == rank && + "Mismatch between number of elements in input and number of indices"); + SmallVector shuffled(rank); + + for (auto i : llvm::enumerate(targetIndices)) { + shuffled[i.index()] = unshuffled[i.value()]; + } + return shuffled; +} + +// Transpose the given tensor based on the given transpose indices. Marks the +// created transpose based on the propagation direction. 
+static Value createTranspose(PatternRewriter &rewriter, Location loc, + Value input, TransposeIndices targetIndices, + bool propagateUp) { + RankedTensorType inType = input.getType().cast(); + auto elementType = inType.getElementType(); + auto inputShape(inType.getShape()); + + auto outputShape = + shuffleFromIndices(llvm::to_vector(inputShape), targetIndices); + + Value output = + rewriter.create(loc, outputShape, elementType); + output.getDefiningOp()->setAttr(transposeEmptyMarker, rewriter.getUnitAttr()); + + auto transpose = + rewriter.create(loc, input, output, targetIndices); + transpose->setAttr( + propagateUp ? transposePropagateUpMarker : transposePropagateDownMarker, + rewriter.getUnitAttr()); + return transpose.getResults()[0]; +} + +// Supports conv and pooling ops, where pooling ops don't transpose the filter. +template +static LogicalResult convertConvLikeNchwToNhwc(PatternRewriter &rewriter, + ConvOpTy convOp, + bool transposeFilter) { + LLVM_DEBUG(llvm::dbgs() << "inspecting " << convOp << "\n"); + + Location loc = convOp.getLoc(); + + Value input = convOp.image(); + Value filter = convOp.filter(); + Value output = convOp.getOutputs()[0]; + + if (!isStaticallyShaped(input) || !isStaticallyShaped(output) || + (transposeFilter && !isStaticallyShaped(filter))) { + return failure(); + } + + TransposeIndices NCHWIndices = {0, 2, 3, 1}; + + auto transposedInput = + createTranspose(rewriter, loc, input, NCHWIndices, true); + auto transposedFilter = filter; + if (transposeFilter) { + TransposeIndices FCHWIndices = {2, 3, 1, 0}; + transposedFilter = + createTranspose(rewriter, loc, filter, FCHWIndices, true); + } + auto transposedOutput = + createTranspose(rewriter, loc, output, NCHWIndices, true); + + auto conv = + rewriter + .create(loc, transposedOutput.getType(), + ValueRange{transposedInput, transposedFilter}, + transposedOutput, convOp.getStrides(), + convOp.getDilations()) + .getResult(0); + + auto returnToNCHW = + createTranspose(rewriter, loc, 
conv, invertIndices(NCHWIndices), false); + + rewriter.replaceOp(convOp, returnToNCHW); + return success(); +} + +namespace { + +/* + * Convolution conversion patterns + */ + +struct ConvertLinalgConvNchwFchw : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp, + PatternRewriter &rewriter) const override { + return convertConvLikeNchwToNhwc(rewriter, convOp, + true); + } +}; + +struct ConvertLinalgPoolingNchwMax + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::PoolingNchwMaxOp poolOp, + PatternRewriter &rewriter) const override { + return convertConvLikeNchwToNhwc(rewriter, poolOp, + false); + } +}; + +struct ConvertLinalgPoolingNchwSum + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::PoolingNchwSumOp poolOp, + PatternRewriter &rewriter) const override { + return convertConvLikeNchwToNhwc(rewriter, poolOp, + false); + } +}; + +/* + * Transpose propagation patterns + */ + +struct PropagateThroughTensorPadPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PropagateThroughTensorPadPattern(MLIRContext *context, bool propagateUp) + : OpRewritePattern(context), propagateUp(propagateUp) {} + + LogicalResult matchAndRewrite(tensor::PadOp padOp, + PatternRewriter &rewriter) const override { + TransposeIndices transposeIndices; + + if (propagateUp) { + auto indices = getIndicesFromOutput(padOp.getResult()); + if (!indices.has_value()) return failure(); + transposeIndices = indices.value(); + } else { + auto indices = getIndicesFromInput(padOp.getSource()); + if (!indices.has_value()) return failure(); + transposeIndices = invertIndices(indices.value()); + } + + LLVM_DEBUG(llvm::dbgs() << "propagating " << padOp << "\n"); + + Location loc = padOp.getLoc(); + + auto input = padOp.getSource(); + SmallVector mixedLow = shuffleFromIndices( + 
padOp.getMixedLowPad(), transposeIndices); + SmallVector mixedHigh = shuffleFromIndices( + padOp.getMixedHighPad(), transposeIndices); + + auto transposedInput = + createTranspose(rewriter, loc, input, transposeIndices, true); + + SmallVector outputShape(padOp.getResultType().getShape()); + SmallVector transposedOutputShape = + shuffleFromIndices(outputShape, transposeIndices); + RankedTensorType transposedOutputType = RankedTensorType::get( + transposedOutputShape, padOp.getResultType().getElementType()); + + auto newPad = rewriter.create(loc, transposedOutputType, + transposedInput, mixedLow, + mixedHigh, padOp.getNofold()); + IRMapping mapper; + padOp.getRegion().cloneInto(&newPad.getRegion(), mapper); + + auto returnToNCHW = createTranspose(rewriter, loc, newPad.getResult(), + invertIndices(transposeIndices), false); + + rewriter.replaceOp(padOp, returnToNCHW); + return success(); + } + + private: + bool propagateUp; +}; + +struct PropagateThroughLinalgFillPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PropagateThroughLinalgFillPattern(MLIRContext *context, bool propagateUp) + : OpRewritePattern(context), propagateUp(propagateUp) {} + + LogicalResult matchAndRewrite(linalg::FillOp fillOp, + PatternRewriter &rewriter) const override { + TransposeIndices transposeIndices; + + if (propagateUp) { + auto indices = getIndicesFromOutput(fillOp.getResult(0)); + if (!indices.has_value()) return failure(); + transposeIndices = indices.value(); + } else { + auto indices = getIndicesFromInput(fillOp.value()); + if (!indices.has_value()) return failure(); + transposeIndices = invertIndices(indices.value()); + } + + LLVM_DEBUG(llvm::dbgs() << "propagating " << fillOp << "\n"); + Location loc = fillOp.getLoc(); + + auto transposedOutput = + createTranspose(rewriter, loc, fillOp.output(), transposeIndices, true); + + auto newTensor = + rewriter.create(loc, fillOp.value(), transposedOutput) + .getResult(0); + + auto returnToNCHW = 
createTranspose(rewriter, loc, newTensor, + invertIndices(transposeIndices), false); + + rewriter.replaceOp(fillOp, returnToNCHW); + return success(); + } + + private: + bool propagateUp; +}; + +struct PropagateThroughLinalgGenericPattern + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PropagateThroughLinalgGenericPattern(MLIRContext *context, bool propagateUp) + : OpRewritePattern(context), + propagateUp(propagateUp) {} + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter &rewriter) const override { + TransposeIndices transposeIndices; + + // For now restrict to single results. + if (genericOp.getNumResults() != 1) return failure(); + + if (propagateUp) { + auto indices = getIndicesFromOutput(genericOp.getOutputs()[0]); + if (!indices.has_value()) return failure(); + transposeIndices = indices.value(); + } else { + // TODO: Enable directly fusing the transpose with the inputs. + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "propagating " << genericOp << "\n"); + + Location loc = genericOp.getLoc(); + + auto transposedOutput = genericOp.getOutputs()[0]; + auto indexingMaps = genericOp.getIndexingMapsArray(); + + if (propagateUp) { + transposedOutput = createTranspose(rewriter, loc, transposedOutput, + transposeIndices, true); + + AffineMap outMap = indexingMaps.back(); + SmallVector outExprs(outMap.getResults()); + SmallVector exprs = + shuffleFromIndices(outExprs, transposeIndices); + indexingMaps[indexingMaps.size() - 1] = + AffineMap::get(outMap.getNumDims(), outMap.getNumSymbols(), exprs, + genericOp->getContext()); + } + + SmallVector newInputs; + for (auto input : llvm::enumerate(genericOp.getInputs())) { + newInputs.push_back(input.value()); + } + + SmallVector iteratorTypes = + genericOp.getIteratorTypesArray(); + + auto newGeneric = rewriter.create( + loc, transposedOutput.getType().cast(), newInputs, + transposedOutput, indexingMaps, iteratorTypes); + IRMapping mapper; + 
genericOp.getRegion().cloneInto(&newGeneric.getRegion(), mapper); + + Value returnToNCHW = newGeneric.getResult(0); + if (propagateUp) { + returnToNCHW = createTranspose(rewriter, loc, returnToNCHW, + invertIndices(transposeIndices), false); + } + + rewriter.replaceOp(genericOp, returnToNCHW); + return success(); + } + + private: + bool propagateUp; +}; + +struct PropagateThroughTensorEmptyPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::EmptyOp emptyOp, + PatternRewriter &rewriter) const override { + if (emptyOp->hasAttr(transposeEmptyMarker)) return failure(); + TransposeIndices transposeIndices; + + auto indices = getIndicesFromOutput(emptyOp.getResult()); + if (!indices.has_value()) return failure(); + transposeIndices = indices.value(); + + LLVM_DEBUG(llvm::dbgs() << "propagating " << emptyOp << "\n"); + + Location loc = emptyOp.getLoc(); + + SmallVector mixedSizes = shuffleFromIndices( + emptyOp.getMixedSizes(), transposeIndices); + + auto newTensor = rewriter.create( + loc, mixedSizes, emptyOp.getType().getElementType()); + auto returnToNCHW = createTranspose(rewriter, loc, newTensor.getResult(), + invertIndices(transposeIndices), false); + + rewriter.replaceOp(emptyOp, returnToNCHW); + return success(); + } +}; + +/* + * Folding away cancelling transposes and generalizing + */ + +// Cancel if this transpose is tagged with a propagating tag and the defining op +// for the input is the inverse of this transpose +struct CancelNCHWToNHWCTransposePattern + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + auto transposeIndices = invertIndices(getTransposeIndices(transposeOp)); + + auto parentOp = + transposeOp->getOperand(0).getDefiningOp(); + if (parentOp) { + if (getTransposeIndices(parentOp) == transposeIndices) { + rewriter.replaceOp(transposeOp, 
parentOp->getOperand(0)); + return success(); + } + } + + return failure(); + } +}; + +struct GeneralizeTransposeOpPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + if (transposeOp->hasAttr(transposePropagateUpMarker) || + transposeOp->hasAttr(transposePropagateDownMarker)) { + auto context = rewriter.getContext(); + auto rank = + transposeOp.getResultTypes()[0].cast().getRank(); + + auto transposeIndices = getTransposeIndices(transposeOp); + + SmallVector idExprs; + for (auto i = 0; i < rank; i++) + idExprs.push_back(getAffineDimExpr(i, context)); + + SmallVector swapExprs = + shuffleFromIndices(idExprs, transposeIndices); + + SmallVector indexingMaps = { + AffineMap::get(rank, 0, idExprs, context), + AffineMap::get(rank, 0, swapExprs, context)}; + SmallVector iteratorTypes( + rank, utils::IteratorType::parallel); + + rewriter.replaceOpWithNewOp( + transposeOp, transposeOp.getResultTypes()[0], + transposeOp.getOperand(0), transposeOp.getOperand(1), indexingMaps, + iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }); + return success(); + } + return failure(); + } +}; + +// The high level strategy for this pass is as follows: +// 1. Do the conversions for all conv_nchw_fchw ops (and pooling ops) and +// wrap the converted convolutions in transposes. Each transpose is tagged +// to indicate which direction the transpose should propagate through the +// graph. +// 2. Traverse the ops in the function in reverse to propagate transposes +// marked for upwards propagation to their parents. Ideally just before ops +// such as arith.constant or function arguments. +// 3. Propagate the transposes marked for downward propagation to its users, +// ideally to just before return. +// 4. 
Canonicalize out all adjacent cancelling transposes and generalize the +// remaining transposes to allow for fusing them with nearby ops. +struct ConvertConvNchwToNhwcPass + : public ConvertConvNchwToNhwcBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + Operation *funcOp = getOperation(); + MLIRContext *context = &getContext(); + + { + RewritePatternSet patterns(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + } + + // Propagate transposes up the graph. + { + SmallVector ops; + funcOp->walk([&](Operation *op) { ops.push_back(op); }); + + RewritePatternSet patterns(context); + patterns.insert(context, true); + patterns.insert(context); + patterns.insert(context, true); + patterns.insert(context, true); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + + SmallVector reverseOps(llvm::reverse(ops)); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::AnyOp; + (void)applyOpPatternsAndFold(reverseOps, frozenPatterns, config); + } + + // Propagate transposes down the graph. + { + RewritePatternSet patterns(context); + patterns.insert(context, false); + patterns.insert(context, false); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + + // Cancel out transposes. + { + RewritePatternSet patterns(context); + patterns.insert(context); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + + // Generalize remaining transposes to allow fusion with other ops. 
+ { + RewritePatternSet patterns(context); + patterns.insert(context); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + } +}; + +} // namespace + +std::unique_ptr> +createConvertConvNchwToNhwcPass() { + return std::make_unique(); +} + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h index 7f78aef82a159..62181f1d2ab47 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h @@ -10,6 +10,7 @@ #include #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -21,11 +22,15 @@ namespace IREE { /// using im2col tranformation. std::unique_ptr createConvertConv2DToImg2ColPass(); +// Creates a pass to convert linalg NCHW Convolutions to NHWC. +std::unique_ptr> +createConvertConvNchwToNhwcPass(); + /// Moves the body of the entire function into a single dispatch. std::unique_ptr> createMakeSingleDispatchForFunctionPass(); -/// A pass to pad linalg ops to the next integer multiple of `paddingSize`. +// A pass to pad linalg ops to the next integer multiple of `paddingSize`. 
std::unique_ptr createPadLinalgOpsToIntegerMultiplePass(); //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index 3a26cc7bc49ff..2b426721aff24 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -15,6 +15,13 @@ def ConvertConv2DToImg2Col : let constructor = "mlir::iree_compiler::IREE::createConvertConv2DToImg2ColPass()"; } +def ConvertConvNchwToNhwc : + InterfacePass<"iree-flow-convert-conv-nchw-to-nhwc", "mlir::FunctionOpInterface"> { + let summary = "Convert linalg NCHW Convolutions to NHWC"; + let constructor = + "mlir::iree_compiler::IREE::createConvertConvNchwToNhwcPass()"; +} + def MakeSingleDispatchForFunction : Pass<"iree-preprocessing-make-single-dispatch-for-function", "func::FuncOp"> { let summary = "Convert entire function into a single dispatch"; diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel index d11313bc5a98d..044ffc3293262 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel @@ -16,6 +16,7 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "conv2d_nchw_to_nhwc.mlir", "conv2d_to_img2col.mlir", "make_single_dispatch_for_function.mlir", "pad_linalg_ops.mlir", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt index 19425cb4944d0..c47bc7b0ea22f 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt @@ -14,6 +14,7 @@ iree_lit_test_suite( NAME lit SRCS + "conv2d_nchw_to_nhwc.mlir" "conv2d_to_img2col.mlir" 
"make_single_dispatch_for_function.mlir" "pad_linalg_ops.mlir" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/conv2d_nchw_to_nhwc.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/conv2d_nchw_to_nhwc.mlir new file mode 100644 index 0000000000000..b7ab1af35a699 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/conv2d_nchw_to_nhwc.mlir @@ -0,0 +1,40 @@ +// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc))" %s | FileCheck %s + +func.func @batch_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> { + %0 = linalg.conv_2d_nchw_fchw + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>) + outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> + return %0 : tensor<8x16x14x14xf32> +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1, d0)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)> +// CHECK: @batch_conv +// CHECK: %[[INPUT:.+]]: tensor<8x4x16x16xf32> +// CHECK: %[[FILTER:.+]]: tensor<16x4x3x3xf32> +// CHECK: %[[OUTPUT:.+]]: tensor<8x16x14x14xf32> +// CHECK: %[[INIT_INPUT_TRANSPOSE:.+]] = tensor.empty() {__nchw_to_nhwc_init__} : tensor<8x16x16x4xf32> +// CHECK: %[[TRANSPOSED_INPUT:.+]] = linalg.generic +// CHECK-SAME: #[[MAP0]] +// CHECK-SAME: #[[MAP1]] +// CHECK-SAME: ins(%[[INPUT]] : tensor<8x4x16x16xf32>) outs(%[[INIT_INPUT_TRANSPOSE]] : tensor<8x16x16x4xf32>) +// CHECK: %[[INIT_FILTER_TRANSPOSE:.+]] = tensor.empty() {__nchw_to_nhwc_init__} : tensor<3x3x4x16xf32> +// CHECK: %[[TRANSPOSED_FILTER:.+]] = linalg.generic +// CHECK-SAME: #[[MAP0]] +// CHECK-SAME: 
#[[MAP2]] +// CHECK-SAME: ins(%[[FILTER]] : tensor<16x4x3x3xf32>) outs(%[[INIT_FILTER_TRANSPOSE]] : tensor<3x3x4x16xf32>) +// CHECK: %[[INIT_OUTPUT_TRANSPOSE:.+]] = tensor.empty() {__nchw_to_nhwc_init__} : tensor<8x14x14x16xf32> +// CHECK: %[[TRANSPOSED_OUTPUT:.+]] = linalg.generic +// CHECK-SAME: #[[MAP0]] +// CHECK-SAME: #[[MAP1]] +// CHECK-SAME: ins(%[[OUTPUT]] : tensor<8x16x14x14xf32>) outs(%[[INIT_OUTPUT_TRANSPOSE]] : tensor<8x14x14x16xf32>) +// CHECK: %[[TRANSPOSED_RESULT:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[TRANSPOSED_INPUT]], %[[TRANSPOSED_FILTER]] : tensor<8x16x16x4xf32>, tensor<3x3x4x16xf32>) outs(%[[TRANSPOSED_OUTPUT]] : tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> +// CHECK: %[[INIT_RESULT:.+]] = tensor.empty() {__nchw_to_nhwc_init__} : tensor<8x16x14x14xf32> +// CHECK: %[[RESULT:.+]] = linalg.generic +// CHECK-SAME: #[[MAP0]] +// CHECK-SAME: #[[MAP3]] +// CHECK-SAME: ins(%[[TRANSPOSED_RESULT]] : tensor<8x14x14x16xf32>) outs(%[[INIT_RESULT]] : tensor<8x16x14x14xf32>) +// CHECK: return %[[RESULT]] : tensor<8x16x14x14xf32> From 27e34822fb2062db5b88d330ed0856a35ad3c7ad Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Fri, 28 Oct 2022 17:45:51 -0700 Subject: [PATCH 07/44] [codegen][spirv] Pack/transpose matrix B for better coop mmma --- .../Codegen/SPIRV/SPIRVTileAndDistribute.cpp | 4 +- .../Flow/Transforms/FormDispatchRegions.cpp | 9 +- .../Dialect/Flow/Transforms/Passes.cpp | 6 + .../compiler/Preprocessing/Common/BUILD.bazel | 4 +- .../Preprocessing/Common/CMakeLists.txt | 2 + .../Common/ConvertLinalgMatmulToMmt.cpp | 119 ++++++++++++++ .../Common/GeneralizeAndFuse.cpp | 148 ++++++++++++++++++ .../compiler/Preprocessing/Common/Passes.h | 6 + .../compiler/Preprocessing/Common/Passes.td | 12 ++ 9 files changed, 306 insertions(+), 4 deletions(-) create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp create mode 100644 
compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp index 260f8bb7857ff..7df6629b0b4f9 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp @@ -98,8 +98,8 @@ static void populateTilingReductionPatterns( .setLoopType(linalg::LinalgTilingLoopType::Loops) .setTileSizeComputationFunction(computeFn); - TilingPatterns::insert( - patterns, tilingOptions, filter); + TilingPatterns::insert(patterns, tilingOptions, filter); filter.addFilter([](Operation *op) { return success(isa(op)); }); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp index ea00c12bb53a5..fbc66e9175d29 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp @@ -653,7 +653,14 @@ isFusableWithProducer(OpOperand &operand, } auto consumerLinalgOp = cast(consumer); - if (!consumerLinalgOp.isDpsInit(&operand)) { + if (consumerLinalgOp.isDpsInput(&operand)) { + // TODO: Add some marker on transpose and MatmulOp to indicate mmt. 
+ bool fuseTransposeAndMatmul = + isa(consumer) && isa(producer); + if (fuseTransposeAndMatmul) { + return true; + } + } else if (!consumerLinalgOp.isDpsInit(&operand)) { return false; } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index 524977e31d43b..b8394b36f7b08 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp @@ -79,6 +79,12 @@ static llvm::cl::opt clDispatchGenerateWorkloadRegion( "iree-flow-dispatch-generate-workload-region", llvm::cl::desc("Generate the workload region."), llvm::cl::init(true)); + +static llvm::cl::opt clEnableTransposeMatmulLayout( + "iree-flow-enable-transpose-matmul-layout", + llvm::cl::desc("Enable transposing the B matrix for matmuls."), + llvm::cl::init(false)); + static llvm::cl::opt clNormalizeInputIndexingMap( "iree-flow-normalize-input-indexing-map", llvm::cl::desc("Enable normalizing input indexing map to identity."), diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index eb1f9fabdc0d4..f8582d347074c 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -32,8 +32,10 @@ iree_compiler_cc_library( srcs = [ "ConvertConv2DToImg2Col.cpp", "ConvertConvNchwToNhwc.cpp", + "ConvertLinalgMatmulToMmt.cpp", + "GeneralizeAndFuse.cpp", "MakeSingleDispatchForFunction.cpp", - "PadLinalgOps.cpp", + "PadLinalgOps.cpp", "PassDetail.h", "Passes.cpp", ], diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index b6ca3b0d3d5bd..b5bd22be2d721 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -28,6 +28,8 @@ 
iree_cc_library( SRCS "ConvertConv2DToImg2Col.cpp" "ConvertConvNchwToNhwc.cpp" + "ConvertLinalgMatmulToMmt.cpp" + "GeneralizeAndFuse.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" "PassDetail.h" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp new file mode 100644 index 0000000000000..f22e55cf92ac0 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp @@ -0,0 +1,119 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +namespace { + +// Converts linalg.matmul to an linalg.transpose + linalg.matmul. +// Such that matrix B layout changes to col major. 
+class LinalgMatmulOpToLinalgMmtPattern final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, + PatternRewriter &rewriter) const override { + Location loc = matmulOp.getLoc(); + Value lhs = matmulOp.getDpsInputOperand(0)->get(); + Value rhs = matmulOp.getDpsInputOperand(1)->get(); + Value acc = matmulOp.getDpsInitOperand(0)->get(); + if (dyn_cast(rhs.getDefiningOp())) { + return failure(); + } + auto rhsType = rhs.getType().cast(); + auto rhsShape = rhsType.getShape(); + auto rhsElemType = rhsType.getElementType(); + SmallVector transposedRhsShape = {rhsShape[1], rhsShape[0]}; + + // GenericOp + int64_t nloops = rhsShape.size(); + AffineExpr mDim, nDim; + bindDims(getContext(), mDim, nDim); + auto inputMap = AffineMap::get(2, 0, {mDim, nDim}, getContext()); + auto packedMap = AffineMap::get(2, 0, {nDim, mDim}, getContext()); + SmallVector indexingMaps = {inputMap, packedMap}; + + Value transposedRhs = + rewriter.create(loc, transposedRhsShape, rhsElemType); + SmallVector loopAttributeTypes( + nloops, utils::IteratorType::parallel); + + Value packedRhs = + rewriter + .create( + loc, transposedRhs.getType(), + /*inputs=*/rhs, /*outputs=*/transposedRhs, indexingMaps, + loopAttributeTypes, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange args) { + nestedBuilder.create(nestedLoc, args[0]); + }) + .getResult(0); + + // TransposeOp + Value initOp = rewriter.create(loc, rhsShape, rhsElemType); + SmallVector transposedPerm = {1, 0}; + Value transposePackedRhs = + rewriter + .create(loc, packedRhs, initOp, transposedPerm) + .getResults()[0]; + + // MatmulOp + Value packedMatmul = + rewriter + .create(loc, matmulOp.getResult(0).getType(), + ArrayRef{lhs, transposePackedRhs}, + ArrayRef{acc}) + .getResult(0); + rewriter.replaceOp(matmulOp, packedMatmul); + return success(); + } +}; + +struct ConvertLinalgMatmulToMmtPass + : public ConvertLinalgMatmulToMmtBase { + 
void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + // Main pattern. + { + RewritePatternSet patterns(&getContext()); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } + } +}; +} // namespace + +std::unique_ptr createConvertLinalgMatmulToMmtPass() { + return std::make_unique(); +} + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp new file mode 100644 index 0000000000000..8a4e7bee31e3d --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp @@ -0,0 +1,148 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +namespace { + +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + +bool isMatmulOrBatchMatmul(linalg::LinalgOp linalgOp) { + return linalg::isaContractionOpInterface(linalgOp) && + llvm::is_contained({2u, 3u}, linalgOp.getNumParallelLoops()); +} + +//===----------------------------------------------------------------------===// +// Generalize and fusion patterns. +//===----------------------------------------------------------------------===// + +struct GeneralizeAndFusePass + : public GeneralizeAndFuseBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + template + class GeneralizeTargetNamedOpPattern final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LinalgOpType linalgOp, + PatternRewriter &rewriter) const override { + // TODO: Check consumer is transposeOp. 
+ // TODO: Generalize transpos + FailureOr genericOp = + linalg::generalizeNamedOp(rewriter, linalgOp); + if (failed(genericOp)) return failure(); + return success(); + } + }; + + class FuseMatmulAndTranspose final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + // Inspo: + // https://github.com/llvm/llvm-project/blob/4f1c12425179608298dc39f5524ba2612609b5e4/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp + LogicalResult matchAndRewrite(linalg::GenericOp linalgOp, + PatternRewriter &rewriter) const override { + const unsigned rhsId = 1; + if (!isMatmulOrBatchMatmul(linalgOp)) return failure(); + Value rhs = linalgOp.getDpsInputOperand(rhsId)->get(); + auto transposeOp = dyn_cast(rhs.getDefiningOp()); + if (!transposeOp) return failure(); + auto perm = transposeOp.getPermutation(); + auto indexingMaps = linalgOp.getIndexingMaps(); + auto rhsMap = indexingMaps[rhsId].cast().getValue(); + int64_t rank = perm.size(); + if (rhsMap.getNumResults() != rank) return failure(); + SmallVector exprs; + for (auto dim_id : perm) { + exprs.push_back(rhsMap.getResult(dim_id)); + } + AffineMap transposedRhsMap = + AffineMap::get(rhsMap.getNumDims(), 0, exprs, getContext()); + + // TODO: Fold transposeOp as transposed indexing for matmulOp. + // Generate a map set. + auto lhsMap = indexingMaps[0].cast().getValue(); + auto accMap = indexingMaps[2].cast().getValue(); + SmallVector newIndexingMaps = {lhsMap, transposedRhsMap, + accMap}; + + // Generate new list of args. + Value newRhs = transposeOp.getDpsInputOperand(0)->get(); + Value lhs = linalgOp.getDpsInputOperand(0)->get(); + Value acc = linalgOp.getDpsInitOperand(0)->get(); + SmallVector inputs = {lhs, newRhs}; + + // Generate a new genericOp. 
+ linalg::GenericOp genericOp = rewriter.create( + linalgOp.getLoc(), linalgOp.getResultTypes(), /*inputs*/ inputs, + /*outputs*/ acc, newIndexingMaps, linalgOp.getIteratorTypesArray()); + // Block consumerBlock = linalgOp->getRegion(0).front(); + // genericOp.getRegion().push_back(consumerBlock); + // llvm::outs()<<"new op + // regions:"<getNumRegions()<<"\n"; + // llvm::outs()<<"new op + // regions:"<getNumRegions()<<"\n"; + // llvm::outs()<<"new op + // blocks:"<getNumRegions()<<"\n"; + // llvm::outs()<<"old op + // blocks:"<getRegion(0), genericOp.getRegion(), + genericOp.getRegion().begin()); + rewriter.replaceOp(linalgOp, genericOp->getResults()); + return success(); + } + }; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + // Main pattern. + // Generalize + Fuse pattern. + { + RewritePatternSet patterns(&getContext()); + patterns.insert, + FuseMatmulAndTranspose>(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } + } +}; +} // namespace + +std::unique_ptr createGeneralizeAndFusePass() { + return std::make_unique(); +} + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h index 62181f1d2ab47..475f8bda54f2d 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h @@ -26,6 +26,12 @@ std::unique_ptr createConvertConv2DToImg2ColPass(); std::unique_ptr> createConvertConvNchwToNhwcPass(); +// Pass to convert a linalg.matmul into linalg.transpose + linalg.matmul. +std::unique_ptr createConvertLinalgMatmulToMmtPass(); + +// Generalizes named op and try to fuse them +std::unique_ptr createGeneralizeAndFusePass(); + /// Moves the body of the entire function into a single dispatch. 
std::unique_ptr> createMakeSingleDispatchForFunctionPass(); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index 2b426721aff24..0e0bc849675fa 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -22,6 +22,18 @@ def ConvertConvNchwToNhwc : "mlir::iree_compiler::IREE::createConvertConvNchwToNhwcPass()"; } +def ConvertLinalgMatmulToMmt : + Pass<"iree-flow-convert-linalg-matmul-to-mmt", ""> { + let summary = "Convert linalg.matmul to linalg.transpose + linalg.matmul"; + let constructor = "mlir::iree_compiler::IREE::createConvertLinalgMatmulToMmtPass()"; +} + +def GeneralizeAndFuse : + Pass<"iree-flow-generalize-and-fuse", ""> { + let summary = "Generalizes named op and try to fuse them."; + let constructor = "mlir::iree_compiler::IREE::createGeneralizeAndFusePass()"; +} + def MakeSingleDispatchForFunction : Pass<"iree-preprocessing-make-single-dispatch-for-function", "func::FuncOp"> { let summary = "Convert entire function into a single dispatch"; From a1b84efe76a8baa41b61bc336fded8f1e5aad415 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 12 Dec 2022 13:24:05 -0500 Subject: [PATCH 08/44] [winograd] Add winograd convolution attribute control as iree_winograd_conv --- .../Dialect/LinalgExt/Passes/Passes.h | 2 +- .../Passes/ConvertConv2DToWinograd.cpp | 28 +++++++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h index 6ef8bb3f06610..a4fd87c824ac9 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h @@ -165,7 +165,7 @@ 
createTileAndDecomposeWinogradTransformPass(); // Creates a pass to convert linalg convolution ops into a sequence of // linalg_ext.winograd.* ops and linalg.batch_matmul ops using the winograd // tranformation. -std::unique_ptr createConvertConv2DToWinogradPass(); +std::unique_ptr createConvertConv2DToWinogradPass(bool forceWinograd = false); // Creates a pass to convert the softmax op into a sequence of // linalg generic ops. diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp index 0680be4d8e8e5..bece710a76bd3 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp @@ -30,6 +30,8 @@ namespace iree_compiler { namespace IREE { namespace LinalgExt { +static const char winogradAttr[] = "iree_winograd_conv"; + static inline int index(int y, int x, int dimy, int dimx) { return (x + dimx * y); } @@ -134,10 +136,16 @@ template class FoldWinogradFilterTransform final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; + FoldWinogradFilterTransform(MLIRContext *context, bool force) + : OpRewritePattern(context, force), forceWinograd(force) {} LogicalResult matchAndRewrite(ConvOp convOp, PatternRewriter &rewriter) const override { + // Attribute control unless forced. 
+ if (!forceWinograd && !convOp->hasAttr(winogradAttr)) + return failure(); + bool isNchw; if (!isValidConv2d(convOp, isNchw)) return failure(); @@ -187,6 +195,8 @@ class FoldWinogradFilterTransform final : public OpRewritePattern { rewriter.replaceOpWithNewOp(constOp, foldedKernelAttr); return success(); } + private: + bool forceWinograd; }; } // namespace @@ -283,10 +293,16 @@ template class ConvertConvToWinograd final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; + ConvertConvToWinograd(MLIRContext *context, bool force) + : OpRewritePattern(context, force), forceWinograd(force) {} LogicalResult matchAndRewrite(ConvOp convOp, PatternRewriter &rewriter) const override { + // Attribute control unless forced. + if (!forceWinograd && !convOp->hasAttr(winogradAttr)) + return failure(); + bool isNchw; if (!isValidConv2d(convOp, isNchw)) return failure(); @@ -416,10 +432,14 @@ class ConvertConvToWinograd final : public OpRewritePattern { result.replaceAllUsesWith(winogradOutput); return success(); } + private: + bool forceWinograd; }; struct ConvertConv2DToWinogradPass : ConvertConv2DToWinogradBase { + public: + ConvertConv2DToWinogradPass(bool forceWinograd) : forceWinograd(forceWinograd) {} void getDependentDialects(DialectRegistry ®istry) const override { registry .insert(); @@ -430,18 +450,20 @@ struct ConvertConv2DToWinogradPass patterns.insert, FoldWinogradFilterTransform, ConvertConvToWinograd, - ConvertConvToWinograd>(context); + ConvertConvToWinograd>(context, forceWinograd); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } + private: + bool forceWinograd; }; } // namespace -std::unique_ptr createConvertConv2DToWinogradPass() { - return std::make_unique(); +std::unique_ptr createConvertConv2DToWinogradPass(bool forceWinograd) { + return std::make_unique(forceWinograd); } } // namespace LinalgExt From 85fb0176aa2779e4ad17db1f02df07f7fbd9b297 Mon Sep 17 00:00:00 2001 
From: Harsh Date: Tue, 13 Dec 2022 08:59:56 -0800 Subject: [PATCH 09/44] [Winograd] Winograd improvements - Speedup filter transform folding - Add points for 4x4, switch to that tile size - Move winograd after im2col + padding, in im2col do not touch conv if it has been marked as winograd -remove prints/chrono and adjust Attribute rawKernelAttr for windows by Quinn Co-authored-by: Quinn Dawkins --- .../Common/ConvertConv2DToImg2Col.cpp | 8 +++ .../LinalgExt/Utils/WinogradConstants.h | 56 +++++++++++++++++++ .../Passes/ConvertConv2DToWinograd.cpp | 45 +++++++++++---- .../Passes/TileAndDecomposeWinogradPass.cpp | 8 +++ 4 files changed, 107 insertions(+), 10 deletions(-) diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp index 12d9c495b24d1..4b6307c4d1bab 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp @@ -23,6 +23,8 @@ namespace mlir { namespace iree_compiler { namespace IREE { +static const char winogradAttr[] = "iree_winograd_conv"; + static bool hasAllOneValues(DenseIntElementsAttr attr) { return llvm::all_of( attr, [](APInt element) { return element.getSExtValue() == 1; }); @@ -96,6 +98,9 @@ class ConvertConv2DNhwcHwcf final if (!hasAllOneValues(convOp.getDilations())) return failure(); + // Ignore if marked as Winograd convolution + if (convOp->hasAttr(winogradAttr)) return failure(); + Value input = convOp.getInputs()[0]; Value filter = convOp.getInputs()[1]; Value output = convOp.getOutputs()[0]; @@ -405,6 +410,9 @@ class ConvertConv2DNchwFchw final if (!hasAllOneValues(convOp.getDilations())) return failure(); + // Ignore if marked as Winograd convolution + if (convOp->hasAttr(winogradAttr)) return failure(); + Value input = convOp.getInputs()[0]; Value filter = convOp.getInputs()[1]; Value output = convOp.getOutputs()[0]; diff 
--git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h index 77dbb09135b2b..62d93d14e5aa1 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h @@ -82,6 +82,62 @@ const float A_6x6_3x3[] = { // clang-format on +//===----------------------------------------------------------------------===// +// Output tile size = 4, Kernel size = 3 +//===----------------------------------------------------------------------===// +// These constants were obtained from this paper: +// +// Lavin, A. et al (2016) Fast Algorithms for Convolution Neural Networks. +// https://openaccess.thecvf.com/content_cvpr_2016/papers/Lavin_Fast_Algorithms_for_CVPR_2016_paper.pdf +// + +// clang-format off + +const float BT_4x4_3x3[] = { + 4, 0, -5, 0, 1, 0, + 0, -4, -4, 1, 1, 0, + 0, 4, -4, -1, 1, 0, + 0, -2, -1, 2, 1, 0, + 0, 2, -1, -2, 1, 0, + 0, 4, 0, -5, 0, 1 +}; + +const float B_4x4_3x3[] = { + 4, 0, 0, 0, 0, 0, + 0, -4, 4, -2, 2, 4, + -5, -4, -4, -1, -1, 0, + 0, 1, -1, 2, -2, -5, + 1, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 1 +}; + +const float G_4x4_3x3[] = { + 1./4., 0, 0, + -1./6., -1./6., -1./6., + -1./6., 1./6., -1./6., + 1./24., 1./12., 1./6., + 1./24., -1./12., 1./6., + 0, 0, 1 +}; + +const float AT_4x4_3x3[] = { + 1, 1, 1, 1, 1, 0, + 0, 1, -1, 2, -2, 0, + 0, 1, 1, 4, 4, 0, + 0, 1, -1, 8, -8, 1 +}; + +const float A_4x4_3x3[] = { + 1, 0, 0, 0, + 1, 1, 1, 1, + 1, -1, 1, -1, + 1, 2, 4, 8, + 1, -2, 4, -8, + 0, 0, 0, 1 +}; + +// clang-format on + } // namespace Winograd } // namespace LinalgExt } // namespace IREE diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp 
b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp index bece710a76bd3..3a994f1b048fd 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp @@ -24,6 +24,8 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SetVector.h" +#include +#include namespace mlir { namespace iree_compiler { @@ -47,7 +49,7 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) { // TODO: Make this a user-settable parameter once we have support // for more tile sizes -static constexpr int64_t outputTileSize = 6; +static constexpr int64_t outputTileSize = 4; /// This function computes the Winograd filter transform when /// the filter is known to be a constant. Specifically, this @@ -72,30 +74,50 @@ foldFilterTransform(ArrayRef shape, int64_t inputTileSize, const int &kw = isNchw ? shape[3] : shape[1]; const int &ic = isNchw ? shape[1] : shape[2]; const int &oc = isNchw ? 
shape[0] : shape[3]; + //printf("Folding filter with kh = %d, kw = %d, ic = %d, oc = %d\n", kh, kw, ic, oc); const int64_t numElements = inputTileSize * inputTileSize * ic * oc; + float *alloc{nullptr}; + if (!isSplat) { + alloc = (float *) malloc(kh * kw * ic * oc * sizeof(float)); + for (int d2 = 0; d2 < ic; d2++) { + for (int d3 = 0; d3 < oc; d3++) { + for (int d4 = 0; d4 < kernelSize; d4++) { + for (int d5 = 0; d5 < kernelSize; d5++) { + int idx; + if (!isNchw) { + idx = index(d4, d5, d2, d3, kh, kw, ic, oc); + } else { + idx = index(d3, d2, d4, d5, oc, ic, kh, kw); + } + alloc[idx] = input[idx].convertToFloat(); + } + } + } + } + } SmallVector output(numElements, APFloat(0.0f)); for (int d0 = 0; d0 < inputTileSize; d0++) { for (int d1 = 0; d1 < inputTileSize; d1++) { for (int d2 = 0; d2 < ic; d2++) { for (int d3 = 0; d3 < oc; d3++) { - APFloat accum(0.0f); + float accum(0.0f); for (int d4 = 0; d4 < kernelSize; d4++) { for (int d5 = 0; d5 < kernelSize; d5++) { - APFloat ival(splatValue); + float ival{splatValue}; if (!isSplat) { if (!isNchw) { - ival = input[index(d4, d5, d2, d3, kh, kw, ic, oc)]; + ival = alloc[index(d4, d5, d2, d3, kh, kw, ic, oc)]; } else { - ival = input[index(d3, d2, d4, d5, oc, ic, kh, kw)]; + ival = alloc[index(d3, d2, d4, d5, oc, ic, kh, kw)]; } } int idx0 = index(d0, d4, inputTileSize, kernelSize); int idx1 = index(d1, d5, inputTileSize, kernelSize); - accum = accum + APFloat(G[idx0]) * ival * APFloat(G[idx1]); + accum = accum + G[idx0] * ival * G[idx1]; } } int odx = index(d0, d1, d2, d3, inputTileSize, inputTileSize, ic, oc); - output[odx] = accum; + output[odx] = APFloat(accum); if (floatType.isF16()) { bool losesInfo; output[odx].convert(APFloat::IEEEhalf(), @@ -105,6 +127,7 @@ foldFilterTransform(ArrayRef shape, int64_t inputTileSize, } } } + if (alloc) free(alloc); return DenseElementsAttr::get(outputType, output); } @@ -165,10 +188,11 @@ class FoldWinogradFilterTransform final : public OpRewritePattern { const int64_t kernelSize 
= kh; const int64_t inputTileSize = outputTileSize + kernelSize - 1; - DenseIntOrFPElementsAttr kernelAttr; - if (!matchPattern(kernel, m_Constant(&kernelAttr))) { + Attribute rawKernelAttr; + if (!matchPattern(kernel, m_Constant(&rawKernelAttr)) || !isa(rawKernelAttr)) { return failure(); } + DenseIntOrFPElementsAttr kernelAttr = cast(rawKernelAttr); Operation *constOp = kernel.getDefiningOp(); ShapedType type = constOp->getResult(0).getType().cast(); @@ -190,8 +214,9 @@ class FoldWinogradFilterTransform final : public OpRewritePattern { auto resultType = RankedTensorType::get(resultShape, elemType); auto foldedKernelAttr = foldFilterTransform(shape, inputTileSize, kernelSize, resultType, - IREE::LinalgExt::Winograd::G_6x6_3x3, isSplat, + IREE::LinalgExt::Winograd::G_4x4_3x3, isSplat, splatValue, nonSplatValues, elemType, isNchw); + rewriter.replaceOpWithNewOp(constOp, foldedKernelAttr); return success(); } diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp index 52d197a852831..775ee7c14219a 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeWinogradPass.cpp @@ -67,6 +67,10 @@ class ReifyWinogradInputTransform final const int64_t inputTileSize = inputOp.getInputTileSize(); const int64_t outputTileSize = inputOp.getOutputTileSize(); switch (outputTileSize) { + case 4: + B = IREE::LinalgExt::Winograd::B_4x4_3x3; + BT = IREE::LinalgExt::Winograd::BT_4x4_3x3; + break; case 6: B = IREE::LinalgExt::Winograd::B_6x6_3x3; BT = IREE::LinalgExt::Winograd::BT_6x6_3x3; @@ -235,6 +239,10 @@ class ReifyWinogradOutputTransform final const int64_t inputTileSize = outputOp.getInputTileSize(); const int64_t outputTileSize = outputOp.getOutputTileSize(); switch (outputTileSize) 
{ + case 4: + A = IREE::LinalgExt::Winograd::A_4x4_3x3; + AT = IREE::LinalgExt::Winograd::AT_4x4_3x3; + break; case 6: A = IREE::LinalgExt::Winograd::A_6x6_3x3; AT = IREE::LinalgExt::Winograd::AT_6x6_3x3; From 740efa98649adcdeaf9d44c7ee1cec2837a6af04 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 23 Jan 2023 02:24:11 -0500 Subject: [PATCH 10/44] [API] Expose iree-opt in python for applying flow preprocessing passes --- compiler/bindings/python/CMakeLists.txt | 7 + compiler/bindings/python/IREEOptTool.c | 9 ++ .../python/iree/compiler/tools/binaries.py | 2 + .../python/iree/compiler/tools/core.py | 122 ++++++++++++++++++ 4 files changed, 140 insertions(+) create mode 100644 compiler/bindings/python/IREEOptTool.c diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt index ea89d0bedb689..c32c73f267551 100644 --- a/compiler/bindings/python/CMakeLists.txt +++ b/compiler/bindings/python/CMakeLists.txt @@ -205,6 +205,13 @@ add_iree_compiler_busybox_tool( IREECompileTool.c ) +add_iree_compiler_busybox_tool( + IREECompilerIREEOptTool + OUTPUT_NAME iree-opt + SRCS + IREEOptTool.c +) + if(TARGET lld) add_iree_compiler_busybox_tool( IREECompilerLldTool diff --git a/compiler/bindings/python/IREEOptTool.c b/compiler/bindings/python/IREEOptTool.c new file mode 100644 index 0000000000000..5c3f241332512 --- /dev/null +++ b/compiler/bindings/python/IREEOptTool.c @@ -0,0 +1,9 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/tool_entry_points_api.h" + +int main(int argc, char **argv) { return ireeOptRunMain(argc, argv); } diff --git a/compiler/bindings/python/iree/compiler/tools/binaries.py b/compiler/bindings/python/iree/compiler/tools/binaries.py index e2d51c7331818..f73dac34ea5c3 100644 --- a/compiler/bindings/python/iree/compiler/tools/binaries.py +++ b/compiler/bindings/python/iree/compiler/tools/binaries.py @@ -30,6 +30,7 @@ _BUILTIN_TOOLS = [ "iree-compile", + "iree-opt", "iree-lld", ] @@ -42,6 +43,7 @@ # options. "iree-compile": "iree.tools.core", "iree-lld": "iree.tools.core", + "iree-opt": "iree.tools.core", "iree-import-tflite": "iree.tools.tflite", "iree-import-tf": "iree.tools.tf", } diff --git a/compiler/bindings/python/iree/compiler/tools/core.py b/compiler/bindings/python/iree/compiler/tools/core.py index 46ab6fbb30f54..84bc978d44a42 100644 --- a/compiler/bindings/python/iree/compiler/tools/core.py +++ b/compiler/bindings/python/iree/compiler/tools/core.py @@ -25,6 +25,8 @@ "CompilerOptions", "InputType", "OutputFormat", + "preprocess_file", + "preprocess_str", ] # Default testing backend for invoking the compiler. @@ -317,3 +319,123 @@ def query_available_targets(): target_backends = [target for target in target_backends if target] return target_backends + + +# Preprocessing for SHARK (for now simply exposes iree-opt) + + +def build_opt_command_line( + input_file: str, tfs: TempFileSaver, options: CompilerOptions +) -> List[str]: + """Builds a command line for applying specified patterns. + + Args: + input_file: The input file name. + tfs: TempFileSaver. + options: Compiler options. + Returns: + List of strings of command line. + """ + iree_opt = find_tool("iree-opt") + cl = [ + iree_opt, + input_file, + ] + + # Output file. + if options.output_file: + cl.append(f"-o={options.output_file}") + + # Tool paths. 
+ lld_path = find_tool("iree-lld") + cl.append(f"--iree-llvm-embedded-linker-path={lld_path}") + + crash_reproducer_path = tfs.alloc_optional( + "core-reproducer.mlir", export_as=options.crash_reproducer_path + ) + if crash_reproducer_path: + cl.append(f"--mlir-pass-pipeline-crash-reproducer={crash_reproducer_path}") + + cl.extend(options.extra_args) + print(cl) + return cl + + +def preprocess_file(input_file: str, **kwargs): + """Invokes iree-opt on an input file. + + Args: + input_file: File containing MLIR assembly to compile. + **kwargs: Keyword arguments corresponding to CompilerOptions. + Returns: + Either a byte buffer of the compiled content or None if output_file + was specified in the options. + """ + with TempFileSaver.implicit() as tfs: + options = CompilerOptions(**kwargs) + retained_output_file = tfs.alloc_optional( + "core-output.bin", export_as=options.output_file + ) + if options.output_file: + options.output_file = retained_output_file + cl = build_opt_command_line(input_file, tfs, options) + + # Save a temp file with the command line. + retained_cl = tfs.alloc_optional("core-command-line.txt") + if retained_cl: + with open(retained_cl, "wt") as f: + f.write(" ".join(cl)) + + result = invoke_immediate(cl) + if options.output_file: + return None + # Output as string needs to write to the retained output file itself. + if retained_output_file: + with open(retained_output_file, "wb") as f: + f.write(result) + return result + + +def preprocess_str(input_str: Union[str, bytes], **kwargs): + """Invokes the IREE compiler with an input string. + + Args: + input_str: MLIR assembly to parse/compile (str or bytes). + **kwargs: Keyword arguments corresponding to CompilerOptions. + Returns: + Either a byte buffer of the compiled content or None if output_file + was specified in the options. 
+ """ + with TempFileSaver.implicit() as tfs: + retained_input_file = tfs.alloc_optional("core-input.mlir") + if retained_input_file: + with open( + retained_input_file, "wt" if isinstance(input_str, str) else "wb" + ) as f: + f.write(input_str) + options = CompilerOptions(**kwargs) + retained_output_file = tfs.alloc_optional( + "core-output.bin", export_as=options.output_file + ) + if options.output_file: + options.output_file = retained_output_file + cl = build_opt_command_line("-", tfs, options) + input_bytes = ( + input_str.encode("utf-8") if isinstance(input_str, str) else input_str + ) + + # Save a temp file with the command line. + retained_cl = tfs.alloc_optional("core-command-line.txt") + if retained_cl: + with open(retained_cl, "wt") as f: + f.write(" ".join(cl)) + + result = invoke_immediate(cl, immediate_input=input_bytes) + if options.output_file: + return None + + # Output as string needs to write to the retained output file itself. + if retained_output_file: + with open(retained_output_file, "wb") as f: + f.write(result) + return result From 91fb38067a597c8802347d6c02ecc10e5c4f9bb6 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Tue, 28 Mar 2023 11:37:38 -0700 Subject: [PATCH 11/44] [COMPILER] Add a plugin to split MLIR functions Add pass to insert markers for function bisecting. Add pass to outline marked operation ranges into separate functions. 
--- .../iree_compiler_plugin_group.cmake | 7 + experimental/split_mlir/lit.cfg.py | 41 +++ experimental/split_mlir/src/CMakeLists.txt | 20 ++ .../split_mlir/src/iree/CMakeLists.txt | 7 + .../src/iree/split_mlir/CMakeLists.txt | 52 +++ .../src/iree/split_mlir/MarkBisectPassImpl.h | 87 +++++ .../split_mlir/OutlineFunctionsPassImpl.h | 299 ++++++++++++++++++ .../split_mlir/src/iree/split_mlir/Passes.cpp | 8 + .../split_mlir/src/iree/split_mlir/Passes.h | 37 +++ .../split_mlir/src/iree/split_mlir/Passes.td | 79 +++++ .../iree/split_mlir/PluginRegistration.cpp | 32 ++ .../src/iree/split_mlir/test/CMakeLists.txt | 27 ++ .../split_mlir/test/function_outlining.mlir | 102 ++++++ .../src/iree/split_mlir/test/mark_bisect.mlir | 68 ++++ 14 files changed, 866 insertions(+) create mode 100644 experimental/split_mlir/iree_compiler_plugin_group.cmake create mode 100644 experimental/split_mlir/lit.cfg.py create mode 100644 experimental/split_mlir/src/CMakeLists.txt create mode 100644 experimental/split_mlir/src/iree/CMakeLists.txt create mode 100644 experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt create mode 100644 experimental/split_mlir/src/iree/split_mlir/MarkBisectPassImpl.h create mode 100644 experimental/split_mlir/src/iree/split_mlir/OutlineFunctionsPassImpl.h create mode 100644 experimental/split_mlir/src/iree/split_mlir/Passes.cpp create mode 100644 experimental/split_mlir/src/iree/split_mlir/Passes.h create mode 100644 experimental/split_mlir/src/iree/split_mlir/Passes.td create mode 100644 experimental/split_mlir/src/iree/split_mlir/PluginRegistration.cpp create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/CMakeLists.txt create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/function_outlining.mlir create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/mark_bisect.mlir diff --git a/experimental/split_mlir/iree_compiler_plugin_group.cmake b/experimental/split_mlir/iree_compiler_plugin_group.cmake new file mode 100644 
index 0000000000000..5ec79cd633920 --- /dev/null +++ b/experimental/split_mlir/iree_compiler_plugin_group.cmake @@ -0,0 +1,7 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src split_mlir) diff --git a/experimental/split_mlir/lit.cfg.py b/experimental/split_mlir/lit.cfg.py new file mode 100644 index 0000000000000..eba45391c230b --- /dev/null +++ b/experimental/split_mlir/lit.cfg.py @@ -0,0 +1,41 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Lint for undefined variables is disabled as config is not defined inside this +# file, instead config is injected by way of evaluating runlit.cfg.py from +# runlit.site.cfg.py which in turn is evaluated by lit.py. +# pylint: disable=undefined-variable + +import os +import tempfile + +import lit.formats + +config.name = "IREE" +config.suffixes = [".mlir", ".txt"] +config.test_format = lit.formats.ShTest(execute_external=True) + +# Forward all IREE environment variables, as well as some passthroughs. +# Note: env vars are case-insensitive on Windows, so check matches carefully. +# https://stackoverflow.com/q/7797269 +passthrough_env_vars = [ + # The Vulkan loader uses this + "VK_ICD_FILENAMES", + # WindowsLinkerTool uses these from vcvarsall + "VCTOOLSINSTALLDIR", + "UNIVERSALCRTSDKDIR", + "UCRTVERSION" +] +config.environment.update({ + k: v + for k, v in os.environ.items() + if k.startswith("IREE_") or k in passthrough_env_vars +}) + +# Use the most preferred temp directory. 
+config.test_exec_root = (os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") or + os.environ.get("TEST_TMPDIR") or + os.path.join(tempfile.gettempdir(), "lit")) diff --git a/experimental/split_mlir/src/CMakeLists.txt b/experimental/split_mlir/src/CMakeLists.txt new file mode 100644 index 0000000000000..da48cd6216517 --- /dev/null +++ b/experimental/split_mlir/src/CMakeLists.txt @@ -0,0 +1,20 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_cc_library( + NAME + defs + INCLUDES + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_BINARY_DIR} + PUBLIC +) + +# Configures all iree_cc_* targets to take this implicit dep, +# which provides common includes and copts for the tree. +set(IREE_IMPLICIT_DEFS_CC_DEPS iree::experimental::split_mlir::src::defs) + +iree_add_all_subdirs() diff --git a/experimental/split_mlir/src/iree/CMakeLists.txt b/experimental/split_mlir/src/iree/CMakeLists.txt new file mode 100644 index 0000000000000..33551b5769745 --- /dev/null +++ b/experimental/split_mlir/src/iree/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() diff --git a/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt b/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt new file mode 100644 index 0000000000000..63627bf3f76c4 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt @@ -0,0 +1,52 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +iree_tablegen_library( + NAME + PassesIncGen + TD_FILE + "Passes.td" + OUTS + --gen-pass-decls Passes.h.inc +) + +iree_cc_library( + NAME + split_mlir_lib + HDRS + "Passes.h" + "Passes.h.inc" + SRCS + "Passes.cpp" + DEPS + ::PassesIncGen + MLIRFuncDialect + MLIRIR + MLIRPass + PUBLIC +) + +iree_cc_library( + NAME + registration + SRCS + "PluginRegistration.cpp" + DEPS + ::split_mlir_lib + MLIRIR + MLIRPass + iree::compiler::PluginAPI + PUBLIC +) + +iree_compiler_register_plugin( + PLUGIN_ID + split_mlir + TARGET + ::registration +) diff --git a/experimental/split_mlir/src/iree/split_mlir/MarkBisectPassImpl.h b/experimental/split_mlir/src/iree/split_mlir/MarkBisectPassImpl.h new file mode 100644 index 0000000000000..1224dcd4818ae --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/MarkBisectPassImpl.h @@ -0,0 +1,87 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/split_mlir/Passes.h" +#include "llvm/ADT/SmallSet.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +namespace iree { +namespace split_mlir { + +#define GEN_PASS_DEF_MARKBISECT +#include "iree/split_mlir/Passes.h.inc" // IWYU pragma: export + +namespace { + +void markRangeFirst(Operation& op, OpBuilder& builder) { + op.setAttr("outline_range_first", builder.getUnitAttr()); +} + +void markRangeLast(Operation& op, OpBuilder& builder) { + op.setAttr("outline_range_last", builder.getUnitAttr()); +} + +struct MarkBisectPass : public impl::MarkBisectBase { + using MarkBisectBase::MarkBisectBase; + + LogicalResult initialize(MLIRContext* context) override { + functionsSet.insert(functions.begin(), functions.end()); + return LogicalResult::success(); + } + + void runOnOperation() override { + mlir::func::FuncOp funcOp = getOperation(); + if (!functionsSet.contains(funcOp.getSymName())) { + return; + } + if (funcOp.getBody().getBlocks().size() > 1) { + return signalPassFailure(); + } + Block& entryBlock = funcOp.getBody().front(); + if (entryBlock.getOperations().size() < 3) { + // Degenerate case. Needs at least 1 op for each half + the return op. + return; + } + size_t opsCount = entryBlock.getOperations().size(); + size_t cutOpIndex = (opsCount - 1) / 2; + OpBuilder builder(&getContext()); + // Ranges are inclusive, [first, last]. 
+ auto firstHalfLastOp = entryBlock.begin(); + std::advance(firstHalfLastOp, cutOpIndex - 1); + markRangeFirst(entryBlock.front(), builder); + markRangeLast(*firstHalfLastOp, builder); + auto secondHalfFirstOp = firstHalfLastOp; + std::advance(secondHalfFirstOp, 1); + markRangeFirst(*secondHalfFirstOp, builder); + auto secondHalfLastOp = entryBlock.end(); + // Take operation that is just before the return operation. + std::advance(secondHalfLastOp, -2); + markRangeLast(*secondHalfLastOp, builder); + } + + private: + llvm::SmallSet functionsSet; +}; + +} // namespace + +std::unique_ptr> createMarkBisectPass() { + return std::make_unique(); +} + +} // namespace split_mlir +} // namespace iree +} // namespace mlir diff --git a/experimental/split_mlir/src/iree/split_mlir/OutlineFunctionsPassImpl.h b/experimental/split_mlir/src/iree/split_mlir/OutlineFunctionsPassImpl.h new file mode 100644 index 0000000000000..783520e7fdb9b --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/OutlineFunctionsPassImpl.h @@ -0,0 +1,299 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include + +#include "iree/split_mlir/Passes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +namespace iree { +namespace split_mlir { + +#define GEN_PASS_DEF_OUTLINEFUNCTIONS +#include "iree/split_mlir/Passes.h.inc" // IWYU pragma: export + +namespace { + +// Collect all operation ranges that are marked for outlining. +// The beginning of a range is marked with the outline_range_first attribute. +// The last operation of a range carries the outline_range_last attribute. +// Example: +// %0 = arith.addi %arg0, %arg1 {outline_range_first} : i32 +// %1 = arith.addi %arg2, %arg3 : i32 +// %2 = arith.muli %arg3, %arg4 {outline_range_last} : i32 +// The outline range will consist of the 3 operations. +LogicalResult getOutlineOpRanges( + Block& block, SmallVector, 4>& res) { + bool isInOutliningRange = false; + Block::iterator rangeBegin; + for (Block::iterator opIt = block.begin(); opIt != block.end(); ++opIt) { + if (opIt->hasAttr("outline_range_first")) { + if (isInOutliningRange) { + return LogicalResult::failure(); + } + isInOutliningRange = true; + rangeBegin = opIt; + } + + if (opIt->hasAttr("outline_range_last")) { + if (!isInOutliningRange) { + return LogicalResult::failure(); + } + isInOutliningRange = false; + res.emplace_back(rangeBegin, std::next(opIt)); + } + } + if (isInOutliningRange) { + // No matching closing marker outline_range_last. 
+ return LogicalResult::failure(); + } + + return LogicalResult::success(); +} + +// Return all values that are an operand of some of the given ops that are +// produced by other ops. Also return all values that are a result of some of +// the given ops and have uses outside the ops range. +std::pair, SmallVector> +getOperandsAndResultsForIsolation(iterator_range opRange, + const SmallPtrSet& opsSet) { + SmallVector operands; + SmallVector results; + SmallPtrSet operandsSet; + SmallPtrSet resultsSet; + for (Operation& op : opRange) { + for (Value operand : op.getOperands()) { + if (!opsSet.contains(operand.getDefiningOp())) { + auto insertionResult = operandsSet.insert(operand); + if (insertionResult.second) { + operands.push_back(operand); + } + } + } + for (OpResult result : op.getResults()) { + for (OpOperand operand : result.getUsers()) { + if (!opsSet.contains(operand.getOwner())) { + auto insertionResult = resultsSet.insert(result); + if (insertionResult.second) { + results.push_back(result); + } + break; + } + } + } + } + return {operands, results}; +} + +template +void replaceValueUsesWithNewBlockArguments(ValueIt valuesBegin, + ValueIt valuesEnd, Block& block) { + for (ValueIt valIt = valuesBegin; valIt != valuesEnd; ++valIt) { + block.addArgument(valIt->getType(), valIt->getLoc()); + BlockArgument& blockArg = block.getArguments().back(); + valIt->replaceUsesWithIf(blockArg, [&block](OpOperand& operand) { + return operand.getOwner()->getBlock() == &block; + }); + } +} + +void addBlockReturn(Block& block, ValueRange operands, OpBuilder& builder) { + func::ReturnOp returnOp = + builder.create(builder.getUnknownLoc(), operands); + block.push_back(returnOp); +} + +void moveOpsIntoBlock(iterator_range opRange, Block& block) { + // Put ops into another container because opRange will be invalidated during + // removal. 
+ SmallVector ops; + std::transform(opRange.begin(), opRange.end(), std::back_inserter(ops), + [](Operation& op) { return &op; }); + for (Operation* op : ops) { + op->moveBefore(&block, block.end()); + } +} + +void moveBlock(Region& srcRegion, Region& destRegion, + Region::iterator srcBlockIt, Region::iterator destBlockIt) { + Block* block = srcRegion.getBlocks().remove(srcBlockIt); + destRegion.getBlocks().insert(destBlockIt, block); +} + +bool isAncestorOfBlock(Operation* op, Block* block) { + // Walk up the operation hierarchy and check each block. + while (op != nullptr) { + if (op->getBlock() == block) { + return true; + } + op = op->getParentOp(); + } + return false; +} + +template +void substititeUses(OriginalOpResultsIt originalBegin, + OriginalOpResultsIt originalEnd, NewOpResultsIt newBegin, + NewOpResultsIt newEnd, Block& excludedBlock) { + assert(std::distance(originalBegin, originalEnd) == + std::distance(newBegin, newEnd)); + auto newIt = newBegin; + for (auto originalIt = originalBegin; originalIt != originalEnd; + ++originalIt, ++newIt) { + originalIt->replaceUsesWithIf(*newIt, [&excludedBlock](OpOperand& operand) { + return !isAncestorOfBlock(operand.getOwner(), &excludedBlock); + }); + } +} + +// All operations in the range `opRange` are moved into a new function with name +// `name`. The resulting function is put inside `moduleOp` and is properly +// isolated from above. This does not insert a call to the new function in place +// of the moved operations. 
+func::FuncOp createFunctionFromOps(iterator_range opRange, + StringRef name, ModuleOp moduleOp, + SmallVector& rangeOperands, + SmallVector& rangeResults, + OpBuilder& builder) { + Region& region = *opRange.begin()->getParentRegion(); + Block& dstBlock = region.emplaceBlock(); + moveOpsIntoBlock(opRange, dstBlock); + replaceValueUsesWithNewBlockArguments(rangeOperands.begin(), + rangeOperands.end(), dstBlock); + addBlockReturn(dstBlock, + ArrayRef(rangeResults.begin(), rangeResults.end()), + builder); + func::FuncOp funcOp = builder.create( + builder.getUnknownLoc(), name, + FunctionType::get(builder.getContext(), dstBlock.getArgumentTypes(), + dstBlock.back().getOperandTypes())); + moduleOp.getBodyRegion().getBlocks().front().push_back(funcOp); + moveBlock(region, funcOp.getBody(), std::prev(region.end()), + funcOp.getBody().end()); + + return funcOp; +} + +void createCall(func::FuncOp funcOp, Block& block, Block::iterator pos, + SmallVector& rangeOperands, + SmallVector& rangeResults, OpBuilder& builder) { + func::CallOp callOp = builder.create( + builder.getUnknownLoc(), funcOp, + ArrayRef(rangeOperands.begin(), rangeOperands.end())); + block.getOperations().insert(pos, callOp); + substititeUses(rangeResults.begin(), rangeResults.end(), + callOp.getResults().begin(), callOp.getResults().end(), + funcOp.getBody().back()); +} + +std::optional outlineOpRange( + iterator_range opRange, StringRef name, ModuleOp moduleOp, + OpBuilder& builder) { + if (opRange.empty()) { + return std::nullopt; + } + + SmallPtrSet opsSet; + for (Operation& op : opRange) { + opsSet.insert(&op); + } + SmallVector rangeOperands; + SmallVector rangeResults; + std::tie(rangeOperands, rangeResults) = + getOperandsAndResultsForIsolation(opRange, opsSet); + Block& srcBlock = *opRange.begin()->getBlock(); + + func::FuncOp funcOp = createFunctionFromOps( + opRange, name, moduleOp, rangeOperands, rangeResults, builder); + createCall(funcOp, srcBlock, opRange.end(), rangeOperands, rangeResults, + 
builder); + + return funcOp; +} + +std::string getOutlinedFuncName(StringRef prefix, int blockIndex, + int outlineRangeIndex) { + return (Twine(prefix) + "_outline_" + Twine(blockIndex) + "_" + + Twine(outlineRangeIndex)) + .str(); +} + +void removeOutlineMarkers(iterator_range opRange) { + if (opRange.empty()) { + return; + } + opRange.begin()->removeAttr("outline_range_first"); + std::prev(opRange.end())->removeAttr("outline_range_last"); +} + +// Each marked operation range in `funcOp` is outlined into a new function. +// A call to the new function is inserted in place of the outlined operations. +LogicalResult outlineOpRanges(func::FuncOp funcOp, ModuleOp moduleOp, + OpBuilder& builder) { + Region& funcBody = funcOp.getFunctionBody(); + SmallVector, 4> outlineRanges; + for (auto blockIt : llvm::enumerate(funcBody.getBlocks())) { + outlineRanges.clear(); + if (failed(getOutlineOpRanges(blockIt.value(), outlineRanges))) { + return LogicalResult::failure(); + } + for (auto rangeIt : llvm::enumerate(outlineRanges)) { + removeOutlineMarkers(rangeIt.value()); + std::string name = getOutlinedFuncName(funcOp.getSymName(), + blockIt.index(), rangeIt.index()); + outlineOpRange(rangeIt.value(), name, moduleOp, builder); + } + } + + return LogicalResult::success(); +} + +struct OutlineFunctionsPass + : public impl::OutlineFunctionsBase { + using OutlineFunctionsBase::OutlineFunctionsBase; + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + Block& moduleBlock = *moduleOp.getBody(); + OpBuilder builder(&getContext()); + // Get all functions since we are going to insert new ones + // that we don't want to iterate over. 
+ SmallVector funcOps( + moduleBlock.getOps().begin(), + moduleBlock.getOps().end()); + for (func::FuncOp op : funcOps) { + if (failed(outlineOpRanges(op, moduleOp, builder))) { + return signalPassFailure(); + } + } + } +}; + +} // namespace + +std::unique_ptr> createOutlineFunctionsPass() { + return std::make_unique(); +} + +} // namespace split_mlir +} // namespace iree +} // namespace mlir diff --git a/experimental/split_mlir/src/iree/split_mlir/Passes.cpp b/experimental/split_mlir/src/iree/split_mlir/Passes.cpp new file mode 100644 index 0000000000000..50fd44ad024d2 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/Passes.cpp @@ -0,0 +1,8 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/split_mlir/MarkBisectPassImpl.h" +#include "iree/split_mlir/OutlineFunctionsPassImpl.h" diff --git a/experimental/split_mlir/src/iree/split_mlir/Passes.h b/experimental/split_mlir/src/iree/split_mlir/Passes.h new file mode 100644 index 0000000000000..039d007273652 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/Passes.h @@ -0,0 +1,37 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_SPLIT_MLIR_TRANSFORM_PASSES_H_ +#define IREE_SPLIT_MLIR_TRANSFORM_PASSES_H_ + +#include + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +class ModuleOp; +namespace func { +class FuncOp; +} // namespace func + +namespace iree { +namespace split_mlir { + +#define GEN_PASS_DECL +#include "iree/split_mlir/Passes.h.inc" // IWYU pragma: export + +std::unique_ptr> createOutlineFunctionsPass(); +std::unique_ptr> createMarkBisectPass(); + +#define GEN_PASS_REGISTRATION +#include "iree/split_mlir/Passes.h.inc" // IWYU pragma: export + +} // namespace split_mlir +} // namespace iree +} // namespace mlir + +#endif // IREE_SPLIT_MLIR_TRANSFORM_PASSES_H_ diff --git a/experimental/split_mlir/src/iree/split_mlir/Passes.td b/experimental/split_mlir/src/iree/split_mlir/Passes.td new file mode 100644 index 0000000000000..f07f37dfd5a6d --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/Passes.td @@ -0,0 +1,79 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_SPLIT_MLIR_TRANSFORM_PASSES +#define IREE_SPLIT_MLIR_TRANSFORM_PASSES + +include "mlir/Pass/PassBase.td" + +def OutlineFunctions : + Pass<"iree-outline-functions", "mlir::ModuleOp"> { + let summary = "Outline operations in separate function(s)."; + let description = [{ + Marked operation ranges in a function block are outlined/moved into new functions. + In place of an outlined operations range is inserted a call to the new function. + The resulting function is equivalent to the original. + The ranges for outlining must be marked with the attributes + `outline_range_first`, `outline_range_last`. 
+ + Example: + ```mlir + func.func @f(%arg0: i32, %arg1: i32) -> i32 { + %0 = arith.addi %arg0, %arg1 {outline_range_first} : i32 + %1 = arith.muli %0, %0 : i32 + %2 = arith.muli %1, %1 {outline_range_last} : i32 + %3 = arith.addi %2, %2 : i32 + return %3 : i32 + } + ``` + + The above MLIR will be transformed to: + ```mlir + func.func @f(%arg0: i32, %arg1: i32) -> i32 { + %0 = call @f_outline_0_0(%arg0, %arg1) : (i32, i32) -> i32 + %1 = arith.addi %0, %0 : i32 + return %1 : i32 + } + func.func @f_outline_0_0(%arg0: i32, %arg1: i32) -> i32 { + %0 = arith.addi %arg0, %arg1 : i32 + %1 = arith.muli %0, %0 : i32 + %2 = arith.muli %1, %1 : i32 + return %2 : i32 + } + ``` + + The pass will fail if there is branching to other function blocks + inside a marked operation range. + }]; + let constructor = "mlir::iree::split_mlir::createOutlineFunctionsPass()"; + let dependentDialects = ["mlir::func::FuncDialect"]; +} + +def MarkBisect : Pass<"iree-mark-bisect", "mlir::func::FuncOp"> { + let summary = "Mark operations in function(s) for outlining with bisect strategy."; + let description = [{ + Each function's entry block is bisected, + such that each piece has balanced number of ops. + The two pieces are marked with attributes `outline_range_first` and + `outline_range_last`. These markings serve as input to the `OutlineFunctions` pass. 
+ + Example: + ```bash + iree-opt \ + --iree-plugin=split_mlir \ + --pass-pipeline="builtin.module(func.func(iree-mark-bisect{functions=f,g}))" + my.mlir + ``` + + }]; + let constructor = "mlir::iree::split_mlir::createMarkBisectPass()"; + let options = [ + ListOption<"functions", "functions", "std::string", "List of functions to bisect."> + ]; + let dependentDialects = ["mlir::func::FuncDialect"]; +} + +#endif // IREE_SPLIT_MLIR_TRANSFORM_PASSES diff --git a/experimental/split_mlir/src/iree/split_mlir/PluginRegistration.cpp b/experimental/split_mlir/src/iree/split_mlir/PluginRegistration.cpp new file mode 100644 index 0000000000000..762bfe327b7d3 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/PluginRegistration.cpp @@ -0,0 +1,32 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/PluginAPI/Client.h" +#include "iree/split_mlir/Passes.h" + + +using namespace mlir; +using namespace mlir::iree_compiler; + +namespace { + +struct SplitMlirOptions { + void bindOptions(OptionsBinder &binder) {} +}; + +struct SplitMlirSession : public PluginSession { + static void registerPasses() { + iree::split_mlir::registerPasses(); + } +}; +} // namespace + +IREE_DEFINE_COMPILER_OPTION_FLAGS(SplitMlirOptions); + +extern "C" bool iree_register_compiler_plugin_split_mlir(PluginRegistrar *registrar) { + registrar->registerPlugin("split_mlir"); + return true; +} diff --git a/experimental/split_mlir/src/iree/split_mlir/test/CMakeLists.txt b/experimental/split_mlir/src/iree/split_mlir/test/CMakeLists.txt new file mode 100644 index 0000000000000..27c8e39e0be06 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "mark_bisect.mlir" + TOOLS + FileCheck + iree-opt +) + +iree_lit_test_suite( + NAME + lit + SRCS + "function_outlining.mlir" + TOOLS + FileCheck + iree-opt +) diff --git a/experimental/split_mlir/src/iree/split_mlir/test/function_outlining.mlir b/experimental/split_mlir/src/iree/split_mlir/test/function_outlining.mlir new file mode 100644 index 0000000000000..cb8df4596192a --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/function_outlining.mlir @@ -0,0 +1,102 @@ +// RUN: iree-opt \ +// RUN: --split-input-file \ +// RUN: --iree-plugin=split_mlir \ +// RUN: --pass-pipeline="builtin.module(iree-outline-functions)" %s \ +// RUN: | FileCheck --dump-input-context=100 %s + +// Outline op that does not take any arguments and is not used anywhere. +// CHECK-LABEL: func.func @no_args_and_result +func.func @no_args_and_result() { +// CHECK: call @no_args_and_result_outline_0_0() : () -> () + %cts1 = mhlo.constant {outline_range_first, outline_range_last} dense<1.000000e+00> : tensor<2xf32> +// CHECK-NEXT: {{return$}} + return +} +// CHECK-LABEL: func.func @no_args_and_result_outline_0_0() +// CHECK: mhlo.constant dense<{{.+}}> : tensor<2xf32> +// CHECK-NOT: outline_range_first +// CHECK-NOT: outline_range_last +// CHECK-NEXT: {{return$}} + +// ----- + +// Outline an op that takes one argument and has one result that is used. 
+// CHECK-LABEL: func.func @one_arg_and_one_result +// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xf32>) -> tensor<2xf32> +func.func @one_arg_and_one_result(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK-NEXT: [[RES0:%.+]] = call @one_arg_and_one_result_outline_0_0([[ARG0]]) + %res = mhlo.cosine %arg0 {outline_range_first, outline_range_last} : tensor<2xf32> +// CHECK-NEXT: return [[RES0]] : tensor<2xf32> + return %res : tensor<2xf32> +} +// CHECK-LABEL: func.func @one_arg_and_one_result_outline_0_0 +// CHECK-SAME: ([[ARG1:%.+]]: tensor<2xf32>) -> tensor<2xf32> +// CHECK-NEXT: [[RES1:%.+]] = mhlo.cosine [[ARG1]] : tensor<2xf32> +// CHECK-NEXT: return [[RES1]] : tensor<2xf32> + +// ----- + +// Multiple ops in a range with multiple arguments and results. +// CHECK-LABEL: func.func @multiple_ops +// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32) +func.func @multiple_ops(%arg0: i32, %arg1: i32) -> (i32, i32) { +// CHECK-NEXT: [[RES0:%.+]]:2 = call @multiple_ops_outline_0_0([[ARG0]], [[ARG1]]) : (i32, i32) -> (i32, i32) + %add = arith.addi %arg0, %arg0 {outline_range_first} : i32 + %mul = arith.muli %add, %arg1 {outline_range_last} : i32 +// CHECK-NEXT: return [[RES0]]#0, [[RES0]]#1 : i32, i32 + return %add, %mul : i32, i32 +} +// CHECK-LABEL: func.func @multiple_ops_outline_0_0 +// CHECK-SAME: ([[ARG10:%.+]]: i32, [[ARG11:%.+]]: i32) -> (i32, i32) +// CHECK-NEXT: [[ADD:%.+]] = arith.addi [[ARG10]], [[ARG10]] : i32 +// CHECK-NEXT: [[MUL:%.+]] = arith.muli [[ADD]], [[ARG11]] : i32 +// CHECK-NEXT: return [[ADD]], [[MUL]] : i32, i32 + +// ----- + +// Outline multiple ranges in the same function. 
+// CHECK-LABEL: func.func @multiple_ranges_in_same_func +// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32) +func.func @multiple_ranges_in_same_func(%arg0: i32, %arg1: i32) -> (i32, i32) { +// CHECK-NEXT: [[ADD:%.+]] = call @multiple_ranges_in_same_func_outline_0_0([[ARG0]]) : (i32) -> i32 + %add = arith.addi %arg0, %arg0 {outline_range_first, outline_range_last} : i32 +// CHECK-NEXT: [[MUL:%.+]] = call @multiple_ranges_in_same_func_outline_0_1([[ADD]], [[ARG1]]) : (i32, i32) -> i32 + %mul = arith.muli %add, %arg1 {outline_range_first, outline_range_last} : i32 +// CHECK-NEXT: return [[ADD]], [[MUL]] : i32, i32 + return %add, %mul : i32, i32 +} +// CHECK-LABEL: func.func @multiple_ranges_in_same_func_outline_0_0 +// CHECK-SAME: ([[ARG10:%.+]]: i32) -> i32 +// CHECK-NEXT: [[ADD1:%.+]] = arith.addi [[ARG10]], [[ARG10]] : i32 +// CHECK-NEXT: return [[ADD1]] : i32 +// CHECK-LABEL: func.func @multiple_ranges_in_same_func_outline_0_1 +// CHECK-SAME: ([[ARG20:%.+]]: i32, [[ARG21:%.+]]: i32) -> i32 +// CHECK-NEXT: [[MUL2:%.+]] = arith.muli [[ARG20]], [[ARG21]] : i32 +// CHECK-NEXT: return [[MUL2]] : i32 + +// ----- + +// Outline multiple ranges in different blocks. 
+// CHECK-LABEL: func.func @multiple_ranges_in_different_blocks +// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 +func.func @multiple_ranges_in_different_blocks(%arg0: i32, %arg1: i32) -> i32 { +// CHECK-NEXT: [[ADD:%.+]] = call @multiple_ranges_in_different_blocks_outline_0_0([[ARG0]]) : (i32) -> i32 + %add = arith.addi %arg0, %arg0 {outline_range_first, outline_range_last} : i32 +// CHECK-NEXT: cf.br ^bb1([[ARG1]] : i32) + cf.br ^bb1(%arg1 : i32) +// CHECK-NEXT: ^bb1 +// CHECK-SAME: ([[ARG2:%.+]]: i32) +^bb1 (%arg2: i32): +// CHECK-NEXT: [[MUL:%.+]] = call @multiple_ranges_in_different_blocks_outline_1_0([[ADD]], [[ARG2]]) : (i32, i32) -> i32 + %mul = arith.muli %add, %arg2 {outline_range_first, outline_range_last} : i32 +// CHECK-NEXT: return [[MUL]] : i32 + return %mul : i32 +} +// CHECK-LABEL: func.func @multiple_ranges_in_different_blocks_outline_0_0 +// CHECK-SAME: ([[ARG10:%.+]]: i32) -> i32 +// CHECK-NEXT: [[ADD1:%.+]] = arith.addi [[ARG10]], [[ARG10]] : i32 +// CHECK-NEXT: return [[ADD1]] : i32 +// CHECK-LABEL: func.func @multiple_ranges_in_different_blocks_outline_1_0 +// CHECK-SAME: ([[ARG20:%.+]]: i32, [[ARG21:%.+]]: i32) -> i32 +// CHECK-NEXT: [[MUL2:%.+]] = arith.muli [[ARG20]], [[ARG21]] : i32 +// CHECK-NEXT: return [[MUL2]] : i32 diff --git a/experimental/split_mlir/src/iree/split_mlir/test/mark_bisect.mlir b/experimental/split_mlir/src/iree/split_mlir/test/mark_bisect.mlir new file mode 100644 index 0000000000000..3a4cb72424a71 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/mark_bisect.mlir @@ -0,0 +1,68 @@ +// RUN: iree-opt \ +// RUN: --split-input-file \ +// RUN: --iree-plugin=split_mlir \ +// RUN: --pass-pipeline="builtin.module(func.func(iree-mark-bisect{functions=two_ops,too_few_ops,multiple_ops}))" %s \ +// RUN: | FileCheck %s + +// Each operation is marked as separate range. 
+// CHECK-LABEL: func.func @two_ops +func.func @two_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: mhlo.constant +// CHECK-DAG: outline_range_first +// CHECK-DAG: outline_range_last + %cts1 = mhlo.constant dense<1.000000e+00> : tensor<2xf32> +// CHECK-NEXT: mhlo.add +// CHECK-DAG: outline_range_first +// CHECK-DAG: outline_range_last + %res = mhlo.add %arg0, %cts1 : tensor<2xf32> + return %res : tensor<2xf32> +} + +// ----- + +// Degenerate case with too few ops should not mark anything. +// CHECK-LABEL: func.func @too_few_ops +func.func @too_few_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: mhlo.constant +// CHECK-NOT: outline_range_first +// CHECK-NOT: outline_range_last + %cts1 = mhlo.constant dense<1.000000e+00> : tensor<2xf32> +// CHECK-NEXT: return +// CHECK-NOT: outline_range_first +// CHECK-NOT: outline_range_last + return %cts1 : tensor<2xf32> +} + +// ----- + +// Multiple ops per range. +// CHECK-LABEL: func.func @multiple_ops +func.func @multiple_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: outline_range_first +// CHECK-SAME: dense<1.000000e+00> + %cts1 = mhlo.constant dense<1.000000e+00> : tensor<2xf32> +// CHECK-NEXT: outline_range_last +// CHECK-SAME: dense<2.000000e+00> + %cts2 = mhlo.constant dense<2.000000e+00> : tensor<2xf32> +// CHECK-NEXT: outline_range_first +// CHECK-SAME: dense<3.000000e+00> + %cts3 = mhlo.constant dense<3.000000e+00> : tensor<2xf32> +// CHECK-NEXT: outline_range_last +// CHECK-SAME: dense<4.000000e+00> + %cts4 = mhlo.constant dense<4.000000e+00> : tensor<2xf32> +// CHECK-NEXT: return + return %cts1 : tensor<2xf32> +} + +// ----- + +// Non-listed functions should not be marked.
+// CHECK-LABEL: func.func @function_not_to_mark +func.func @function_not_to_mark(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK-NOT: outline_range_first +// CHECK-NOT: outline_range_last + %cts1 = mhlo.constant dense<1.000000e+00> : tensor<2xf32> + %res = mhlo.add %arg0, %cts1 : tensor<2xf32> +// CHECK: return + return %res : tensor<2xf32> +} From 767c851894d17027c010a9454692a424a20b82db Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Wed, 5 Apr 2023 08:39:26 -0700 Subject: [PATCH 12/44] [split_mlir] Add operation list extraction and execution for IREE Expose a Python binding that has extraction of an operation list from an MLIR file. This list is then used to execute with IREE the entry MLIR while resolving calls to functions in other MLIR files. --- .../src/iree/split_mlir/CMakeLists.txt | 28 +++ .../src/iree/split_mlir/OperationListImpl.h | 181 ++++++++++++++++++ .../src/iree/split_mlir/SplitMlirPyExt.cpp | 37 ++++ .../src/iree/split_mlir/__init__.py | 14 ++ .../src/iree/split_mlir/_split_mlir.pyi | 12 ++ .../src/iree/split_mlir/execution.py | 42 ++++ .../src/iree/split_mlir/iree_execution.py | 122 ++++++++++++ .../split_mlir/test/execution/CMakeLists.txt | 18 ++ .../iree/split_mlir/test/execution/entry.mlir | 14 ++ .../test/execution/execution_test.py | 57 ++++++ .../iree/split_mlir/test/execution/f1.mlir | 10 + .../iree/split_mlir/test/execution/f2.mlir | 10 + .../split_mlir/src/iree/split_mlir/types.py | 19 ++ 13 files changed, 564 insertions(+) create mode 100644 experimental/split_mlir/src/iree/split_mlir/OperationListImpl.h create mode 100644 experimental/split_mlir/src/iree/split_mlir/SplitMlirPyExt.cpp create mode 100644 experimental/split_mlir/src/iree/split_mlir/__init__.py create mode 100644 experimental/split_mlir/src/iree/split_mlir/_split_mlir.pyi create mode 100644 experimental/split_mlir/src/iree/split_mlir/execution.py create mode 100644 experimental/split_mlir/src/iree/split_mlir/iree_execution.py create mode 100644 
experimental/split_mlir/src/iree/split_mlir/test/execution/CMakeLists.txt create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/execution/entry.mlir create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/execution/execution_test.py create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/execution/f1.mlir create mode 100644 experimental/split_mlir/src/iree/split_mlir/test/execution/f2.mlir create mode 100644 experimental/split_mlir/src/iree/split_mlir/types.py diff --git a/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt b/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt index 63627bf3f76c4..d59a3df68730e 100644 --- a/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt +++ b/experimental/split_mlir/src/iree/split_mlir/CMakeLists.txt @@ -50,3 +50,31 @@ iree_compiler_register_plugin( TARGET ::registration ) + +iree_pyext_module( + NAME + PyExt + MODULE_NAME _split_mlir + SRCS + "OperationListImpl.h" + "SplitMlirPyExt.cpp" + DEPS + MLIRFuncDialect + MLIRIR + MLIRAsmParser + iree::compiler::Tools::init_passes_and_dialects +) + +iree_py_library( + NAME + split_mlir_py + SRCS + "__init__.py" + "_split_mlir.pyi" + "execution.py" + "iree_execution.py" + "types.py" + DEPS + MLIRPythonModules + ::PyExt +) diff --git a/experimental/split_mlir/src/iree/split_mlir/OperationListImpl.h b/experimental/split_mlir/src/iree/split_mlir/OperationListImpl.h new file mode 100644 index 0000000000000..df2f4180a2d37 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/OperationListImpl.h @@ -0,0 +1,181 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "iree/compiler/Tools/init_dialects.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace iree { +namespace split_mlir { + +using OpId = std::string; +using OpIndex = size_t; +using ResultIndex = size_t; +using ResultId = std::tuple; +using Arguments = std::vector; +using OperationList = std::vector>; + +ResultIndex getResultIndex(OpOperand& operand) { + OpResult opResult = operand.get().dyn_cast(); + if (opResult) { + return opResult.getResultNumber(); + } + + BlockArgument blockArgument = operand.get().dyn_cast(); + assert(blockArgument); + return blockArgument.getArgNumber(); +} + +FailureOr getDefiningOpIndex( + OpOperand& operand, Block& block, + const std::unordered_map& operationInBlockIndexMap) { + Value value = operand.get(); + if (value.isa()) { + return 0; + } + + OpResult opResult = value.dyn_cast(); + if (!opResult) { + operand.getOwner()->emitError( + Twine("Operand ") + std::to_string(operand.getOperandNumber()) + + "is neigher a block argument or a result of an operation"); + return failure(); + } + if (value.getDefiningOp()->getBlock() != &block) { + operand.getOwner()->emitError( + "Can't extract call graph for block that is not isolated from above."); + return failure(); + } + + auto it = operationInBlockIndexMap.find(value.getDefiningOp()); + assert(it != operationInBlockIndexMap.end()); + return it->second; +} + +std::string getOpId(Operation& op) { + func::CallOp callOp = dyn_cast(op); + if (callOp) { + return (Twine("call ") + callOp.getCallee()).str(); + } + + if (isa(op)) { + return 
"return"; + } + + return op.getName().getStringRef().str(); +} + +FailureOr extractOperationList(Block& block) { + OperationList res; + // Block arguments don't depend on anything. + res.emplace_back(); + // Index inside the block. + std::unordered_map operationInBlockIndexMap; + + for (auto opIt : llvm::enumerate(block)) { + operationInBlockIndexMap.insert({&opIt.value(), opIt.index() + 1}); + OpId id = getOpId(opIt.value()); + Arguments arguments; + for (OpOperand& operand : opIt.value().getOpOperands()) { + FailureOr opIndex = + getDefiningOpIndex(operand, block, operationInBlockIndexMap); + FailureOr resultIndex = getResultIndex(operand); + if (failed(opIndex) || failed(resultIndex)) { + return failure(); + } + arguments.emplace_back(opIndex.value(), resultIndex.value()); + } + res.emplace_back(id, arguments); + } + + return res; +} + +FailureOr> loadMlir(const char* mlirFilePath, + MLIRContext& context) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(mlirFilePath); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return failure(); + } + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + return parseSourceFile(sourceMgr, &context); +} + +func::FuncOp findFunction(Operation* root, StringRef name) { + func::FuncOp res; + root->walk([&](func::FuncOp op) { + if (op.getSymName() == name) { + res = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return res; +} + +FailureOr extractOperationList(ModuleOp moduleOp, + StringRef functionName) { + func::FuncOp funcOp = findFunction(moduleOp.getOperation(), functionName); + Region* region = funcOp.getCallableRegion(); + if (!region) { + funcOp.emitError("No callable region found."); + return failure(); + } + if (region->getBlocks().size() != 1) { + funcOp.emitError("Blocks count must be exactly 1."); + return failure(); + } + return 
extractOperationList(region->front()); +} + +FailureOr extractOperationList(const char* mlirFilePath, + StringRef functionName, + MLIRContext& context) { + auto moduleOp = loadMlir(mlirFilePath, context); + if (failed(moduleOp)) { + return failure(); + } + + return extractOperationList(moduleOp->get(), functionName); +} + +std::unique_ptr makeMlirContext() { + mlir::DialectRegistry registry; + mlir::iree_compiler::registerAllDialects(registry); + auto context = std::make_unique(registry); + return context; +} + +FailureOr extractOperationList(const char* mlirFilePath, + StringRef functionName) { + auto context = makeMlirContext(); + return extractOperationList(mlirFilePath, functionName, *context); +} + +} // namespace split_mlir +} // namespace iree +} // namespace mlir diff --git a/experimental/split_mlir/src/iree/split_mlir/SplitMlirPyExt.cpp b/experimental/split_mlir/src/iree/split_mlir/SplitMlirPyExt.cpp new file mode 100644 index 0000000000000..d1b5560fe660c --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/SplitMlirPyExt.cpp @@ -0,0 +1,37 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include + +#include "OperationListImpl.h" + +namespace py = pybind11; + +namespace mlir { +namespace iree { +namespace split_mlir { + +PYBIND11_MODULE(_split_mlir, m) { + m.doc() = "Split MLIR C++ extension"; + + m.def( + "extract_operation_list", + [](const std::string& mlirFilePath, const std::string& functionName) { + auto res = extractOperationList(mlirFilePath.c_str(), functionName); + if (failed(res)) { + throw std::runtime_error(""); + } + return res.value(); + }, + py::arg("mlir_file_path"), py::arg("function_name")); +} + +} // namespace split_mlir +} // namespace iree +} // namespace mlir diff --git a/experimental/split_mlir/src/iree/split_mlir/__init__.py b/experimental/split_mlir/src/iree/split_mlir/__init__.py new file mode 100644 index 0000000000000..f66106cd2b4a6 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._split_mlir import extract_operation_list +from .execution import execute_operation_list +from .iree_execution import IreeExecutor, execute_mlir_with_iree + +__all__ = [ + "execute_operation_list", "execute_mlir_with_iree", + "extract_operation_list", "IreeExecutor" +] diff --git a/experimental/split_mlir/src/iree/split_mlir/_split_mlir.pyi b/experimental/split_mlir/src/iree/split_mlir/_split_mlir.pyi new file mode 100644 index 0000000000000..f21ca121c93e8 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/_split_mlir.pyi @@ -0,0 +1,12 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .types import OperationList + + +def extract_operation_list(mlir_file_path: str, + function_name: str) -> OperationList: + ... diff --git a/experimental/split_mlir/src/iree/split_mlir/execution.py b/experimental/split_mlir/src/iree/split_mlir/execution.py new file mode 100644 index 0000000000000..ddf8988fb39e8 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/execution.py @@ -0,0 +1,42 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List, Optional +from .types import OpArguments, Tensor, ExecuteOp, OperationList + + +def collect_arguments(arguments: OpArguments, + results: List[List[Tensor]]) -> List[Tensor]: + return [results[arg[0]][arg[1]] for arg in arguments] # type: ignore + + +def execute_operation_list( + input: List[Tensor], + operation_list: OperationList, + execute_op: ExecuteOp, + override_results: Optional[List[List[Tensor]]] = None +) -> List[List[Tensor]]: + """Algorithm to execute a call list. + + Parameters + ---------- + input : Input of the graph. + execute_op : Callable that executes an operation from the graph. + override_results : When executing operations override arguments with these values, + instead of using the computed results from previous functions. + + Returns + ------- + All results from all operations in the graph are in the same order + as they appear in `operation_list`. `input` is prepended to the result.
+ """ + results = [input] + for op in operation_list[1:]: + arguments = collect_arguments( + arguments=op[1], + results=results if override_results is None else override_results) + results.append(execute_op(op[0], arguments)) + return results diff --git a/experimental/split_mlir/src/iree/split_mlir/iree_execution.py b/experimental/split_mlir/src/iree/split_mlir/iree_execution.py new file mode 100644 index 0000000000000..cfcfe92a3a903 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/iree_execution.py @@ -0,0 +1,122 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List, Callable, Tuple, Dict, Optional, Any +from .types import Tensor +import iree.runtime +from iree.runtime import VmModule, HalDevice, load_vm_module +from .execution import execute_operation_list +from tempfile import TemporaryDirectory +from iree.compiler.tools import compile_file +import os +from pathlib import Path +from collections import namedtuple +from ._split_mlir import extract_operation_list +from numbers import Number + +VmfbFilePath = str +MlirFilePath = str +FunctionName = str + + +class IreeExecutor: + """Executor for IREE that implements the `.types.ExecuteOp` interface.""" + + def __init__(self, device: HalDevice, + resolve_function: Callable[[FunctionName], Tuple[VmfbFilePath, + FunctionName]]): + """ + Parameters + ---------- + resolve_function : Resolves a function name that is called in the entry MLIR to + the vmfb file where it can be found under another name. 
+ """ + self.device = device + self.resolve_function = resolve_function + + def __call__(self, op_id: str, operands: List[Tensor]) -> List[Tensor]: + if op_id.startswith("call "): + function_name = op_id.split(" ", 2)[1] + vmfb_file_path, vmfb_function_name = self.resolve_function(function_name) + config = iree.runtime.Config(device=self.device) + with open(vmfb_file_path, "rb") as f: + vm_flatbuffer = f.read() + vm_module_fb_bytes = VmModule.from_flatbuffer(config.vm_instance, + vm_flatbuffer) + vm_module = load_vm_module(vm_module_fb_bytes, config) + res = getattr(vm_module, vmfb_function_name)(*operands) + if isinstance(res, (iree.runtime.DeviceArray, Number)): + res = [res] + return res + if op_id == "return": + return operands + raise RuntimeError(f"Invalid op_id \"{op_id}\".") + + +def mlir_to_vmfb_file_path(mlir_file_path: str) -> str: + return f"{Path(mlir_file_path).stem}.vmfb" + + +def execute_mlir_with_iree(input: List[Tensor], + mlir_path_function_pairs: List[Tuple[MlirFilePath, + FunctionName]], + compile_kwargs: Dict[str, Any], + device: HalDevice, + override_results: Optional[List[ + List[Tensor]]] = None, + artifact_dir: Optional[str] = None) -> List[Tensor]: + """Executes an MLIR program that is split across multiple MLIR files. + Parameters + ---------- + mlir_path_function_pairs : List of MLIR files and the function they contain. + The first element is the entry MLIR and function. + It is expected that the name of a function called in the entry function corresponds + to an MLIR file with the same name without the file name extension. + compile_kwargs : Compile arguments to pass to iree.compiler.tools.compile_file. + artifact_dir : Where to put temporary files. + Defaults to creating a unique temporary directory that is deleted on completion.
+ + See: `execute_operation_list` + """ + if artifact_dir is None: + with TemporaryDirectory() as temp_dir: + return execute_mlir_with_iree( + input=input, + mlir_path_function_pairs=mlir_path_function_pairs, + override_results=override_results, + compile_kwargs=compile_kwargs, + device=device, + artifact_dir=temp_dir) + + entry_mlir_file_path = mlir_path_function_pairs[0][0] + entry_function_name = mlir_path_function_pairs[0][1] + FunctionDescription = namedtuple( + "FunctionDescription", + ["mlir_file_path", "vmfb_file_path", "function_name"]) + function_map = { + Path(Path(p[0]).name).stem: FunctionDescription( + p[0], os.path.join(artifact_dir, mlir_to_vmfb_file_path(p[0])), p[1]) + for p in mlir_path_function_pairs + } + for i in range(1, len(mlir_path_function_pairs)): + function_description = function_map[Path( + Path(mlir_path_function_pairs[i][0]).name).stem] + compile_file(function_description.mlir_file_path, + output_file=function_description.vmfb_file_path, + **compile_kwargs) + + def resolve_function( + function_name: FunctionName) -> Tuple[VmfbFilePath, FunctionName]: + func_desc = function_map[function_name] + return (func_desc.vmfb_file_path, func_desc.function_name) + + executor = IreeExecutor(device=device, resolve_function=resolve_function) + operation_list = extract_operation_list(mlir_file_path=entry_mlir_file_path, + function_name=entry_function_name) + return execute_operation_list(operation_list=operation_list, + execute_op=executor, + input=input, + override_results=override_results) diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/CMakeLists.txt b/experimental/split_mlir/src/iree/split_mlir/test/execution/CMakeLists.txt new file mode 100644 index 0000000000000..77611fc4278a6 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +iree_local_py_test( + NAME + execution_test + SRC + execution_test.py + PACKAGE_DIRS + "${IREE_BINARY_DIR}/compiler/bindings/python" + "${IREE_BINARY_DIR}/runtime/bindings/python" + "${IREE_BINARY_DIR}/compiler/plugins/split_mlir" +) diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/entry.mlir b/experimental/split_mlir/src/iree/split_mlir/test/execution/entry.mlir new file mode 100644 index 0000000000000..ea6cd8b5f4e58 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/entry.mlir @@ -0,0 +1,14 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +func.func nested @f1(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) +func.func nested @f2(%arg0: tensor<1xf32>) -> tensor<1xf32> + +func.func @caller(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) { + %0:2 = call @f1(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) + %1 = call @f2(%0#0) : (tensor<1xf32>) -> tensor<1xf32> + return %arg1, %1 : tensor<1xf32>, tensor<1xf32> +} diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/execution_test.py b/experimental/split_mlir/src/iree/split_mlir/test/execution/execution_test.py new file mode 100644 index 0000000000000..98e5b766eacd6 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/execution_test.py @@ -0,0 +1,57 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +from iree.split_mlir import extract_operation_list, execute_mlir_with_iree +from typing import List, Any +import os +import iree.runtime +import numpy as np + + +def assert_nested_array_equals(a: List[Any], b: List[Any]): + assert a == b, f"{a} != {b}" + + +class ExecutionTest(unittest.TestCase): + + def test_extract_operation_list(self): + expected_operation_list = [ + ("", []), + ("call f1", [(0, 0), (0, 0)]), + ("call f2", [(1, 0)]), + ("return", [(0, 1), (2, 0)]), + ] + operation_list = extract_operation_list(mlir_file_path=os.path.join( + os.path.dirname(__file__), "entry.mlir"), + function_name="caller") + assert_nested_array_equals(expected_operation_list, operation_list) + + def test_mlir_execution(self): + mlir_path_function_pairs = [ + (os.path.join(os.path.dirname(__file__), "entry.mlir"), "caller"), + (os.path.join(os.path.dirname(__file__), "f1.mlir"), "f1"), + (os.path.join(os.path.dirname(__file__), "f2.mlir"), "main"), + ] + compile_kwargs = { + "target_backends": ["llvm-cpu"], + } + device = iree.runtime.get_device("local-task") + input = [np.array([1], dtype=np.float32), np.array([2], dtype=np.float32)] + results = execute_mlir_with_iree( + input=input, + mlir_path_function_pairs=mlir_path_function_pairs, + compile_kwargs=compile_kwargs, + device=device) + expected_output = [ + np.array([2], dtype=np.float32), + np.array([4], dtype=np.float32) + ] + assert_nested_array_equals(results[-1], expected_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/f1.mlir b/experimental/split_mlir/src/iree/split_mlir/test/execution/f1.mlir new file mode 100644 index 0000000000000..8ca9609123a7a --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/f1.mlir @@ -0,0 +1,10 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +func.func @f1(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) { + %0 = arith.addf %arg0, %arg1: tensor<1xf32> + return %0, %0 : tensor<1xf32>, tensor<1xf32> +} diff --git a/experimental/split_mlir/src/iree/split_mlir/test/execution/f2.mlir b/experimental/split_mlir/src/iree/split_mlir/test/execution/f2.mlir new file mode 100644 index 0000000000000..6b0bc790fdbb8 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/test/execution/f2.mlir @@ -0,0 +1,10 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +func.func @main(%arg0: tensor<1xf32>) -> tensor<1xf32> { + %0 = arith.addf %arg0, %arg0 : tensor<1xf32> + return %0 : tensor<1xf32> +} diff --git a/experimental/split_mlir/src/iree/split_mlir/types.py b/experimental/split_mlir/src/iree/split_mlir/types.py new file mode 100644 index 0000000000000..418bc207ab911 --- /dev/null +++ b/experimental/split_mlir/src/iree/split_mlir/types.py @@ -0,0 +1,19 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List, TypeVar, Callable, Tuple +from numbers import Integral + +Tensor = TypeVar("Tensor") +OpId = TypeVar("OpId") +ExecuteOp = Callable[[OpId, List[Tensor]], List[Tensor]] +OperationIndex = Integral +ResultIndex = Integral +"""Description of the dependencies of an operation.""" +OpArguments = List[Tuple[OperationIndex, ResultIndex]] +Operation = Tuple[OpId, OpArguments] +"""Describes a dependency graph of operations.""" +OperationList = List[Operation] From 117bece5d7d5a40e4a0972d721859c9529d8f5ff Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Wed, 22 Mar 2023 19:18:50 -0400 Subject: [PATCH 13/44] [Preprocessing] Add pass to generalize 2d convolutions This can be useful when trying to do layout propagation and guaranteeing specific fusion at time (use with caution). --- .../compiler/Preprocessing/Common/BUILD.bazel | 1 + .../Preprocessing/Common/CMakeLists.txt | 1 + .../Common/GeneralizeConvolutions.cpp | 68 +++++++++++++++++++ .../compiler/Preprocessing/Common/Passes.h | 3 + .../compiler/Preprocessing/Common/Passes.td | 6 ++ 5 files changed, 79 insertions(+) create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index f8582d347074c..3490034349f1a 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -34,6 +34,7 @@ iree_compiler_cc_library( "ConvertConvNchwToNhwc.cpp", "ConvertLinalgMatmulToMmt.cpp", "GeneralizeAndFuse.cpp", + "GeneralizeConvolutions.cpp", "MakeSingleDispatchForFunction.cpp", "PadLinalgOps.cpp", "PassDetail.h", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index b5bd22be2d721..1ef168c44ce39 100644 --- 
a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -30,6 +30,7 @@ iree_cc_library( "ConvertConvNchwToNhwc.cpp" "ConvertLinalgMatmulToMmt.cpp" "GeneralizeAndFuse.cpp" + "GeneralizeConvolutions.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" "PassDetail.h" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp new file mode 100644 index 0000000000000..ada980a788edc --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeConvolutions.cpp @@ -0,0 +1,68 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +namespace { + +template +class GeneralizeTargetNamedOp final : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LinalgOpType linalgOp, + PatternRewriter &rewriter) const override { + FailureOr genericOp = + linalg::generalizeNamedOp(rewriter, linalgOp); + if (failed(genericOp)) return 
failure(); + return success(); + } +}; + +struct GeneralizeConvolutionsPass + : GeneralizeConvolutionsBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(&getContext()); + patterns.insert>(context); + patterns.insert>(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr createGeneralizeConvolutionsPass() { + return std::make_unique(); +} + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h index 475f8bda54f2d..f7ea5661464cf 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h @@ -36,6 +36,9 @@ std::unique_ptr createGeneralizeAndFusePass(); std::unique_ptr> createMakeSingleDispatchForFunctionPass(); +// A pass to generalize all conv-like ops. +std::unique_ptr createGeneralizeConvolutionsPass(); + // A pass to pad linalg ops to the next integer multiple of `paddingSize`. 
std::unique_ptr createPadLinalgOpsToIntegerMultiplePass(); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index 0e0bc849675fa..04e615e763333 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -40,6 +40,12 @@ def MakeSingleDispatchForFunction : let constructor = "mlir::iree_compiler::IREE::createMakeSingleDispatchForFunctionPass()"; } +def GeneralizeConvolutions : + Pass<"iree-preprocessing-generalize-convolutions", ""> { + let summary = "Generalize all convolution ops"; + let constructor = "mlir::iree_compiler::IREE::createGeneralizeConvolutionsPass()"; +} + def PadLinalgOps : Pass<"iree-preprocessing-pad-linalg-ops", ""> { let summary = "Pad linalg ops to the next integer multiple of paddingSize."; From 78bf32443c7387cc9060c6e286d30314d476e361 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 16 Feb 2023 14:57:11 -0500 Subject: [PATCH 14/44] [Preprocessing] Add a pass to convert convolutions to channels last This pass is the spiritual successor to `convert-conv-nchw-to-nhwc` focused on generalizing to enable data tiling and more robust layout propagation, as well as supporting non-named convolutions as well. Currently this includes some baked in generalization patterns and does not support padding. Tile size selection currently is pass-wide, but there is limited attribute control to enable fully transposing. Further generalizations should aim to write this pass by allowing per-op tile size control. 
--- .../compiler/Preprocessing/Common/BUILD.bazel | 3 +- .../Preprocessing/Common/CMakeLists.txt | 1 + .../Common/ConvertConvToChannelsLast.cpp | 890 ++++++++++++++++++ .../compiler/Preprocessing/Common/Passes.h | 3 + .../compiler/Preprocessing/Common/Passes.td | 12 + 5 files changed, 908 insertions(+), 1 deletion(-) create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index 3490034349f1a..08795e578b27b 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -32,8 +32,9 @@ iree_compiler_cc_library( srcs = [ "ConvertConv2DToImg2Col.cpp", "ConvertConvNchwToNhwc.cpp", + "ConvertConvToChannelsLast.cpp", "ConvertLinalgMatmulToMmt.cpp", - "GeneralizeAndFuse.cpp", + "GeneralizeAndFuse.cpp", "GeneralizeConvolutions.cpp", "MakeSingleDispatchForFunction.cpp", "PadLinalgOps.cpp", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index 1ef168c44ce39..1b83bc1ea1d74 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -28,6 +28,7 @@ iree_cc_library( SRCS "ConvertConv2DToImg2Col.cpp" "ConvertConvNchwToNhwc.cpp" + "ConvertConvToChannelsLast.cpp" "ConvertLinalgMatmulToMmt.cpp" "GeneralizeAndFuse.cpp" "GeneralizeConvolutions.cpp" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp new file mode 100644 index 0000000000000..4a8b833bfb1a6 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp @@ -0,0 +1,890 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache 
License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Preprocessing/Common/PassDetail.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-preprocessing-convert-conv-to-channels-last" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +static const StringLiteral fullTileTransposeMarker = "__fully_transpose_tile__"; + +using TransposeIndices = SmallVector; +using ConvBuilderFn = std::function newDimOrder, + SmallVector newIteratorTypes)>; +using linalg::detail::MatchConvolutionResult; + +static Value defaultConvBuilderFn( + OpBuilder &b, Location loc, linalg::LinalgOp srcConv, Value input, + Value filter, Value output, AffineMap inputMap, AffineMap filterMap, + AffineMap outputMap, SmallVector newDimOrder, + SmallVector newIteratorTypes) { + AffineMap newInputMap = inputMap; + AffineMap newFilterMap = filterMap; + AffineMap newOutputMap = outputMap; + if (!newDimOrder.empty()) { + DenseMap dimMap; + for (auto [newDim, oldDim] : llvm::enumerate(newDimOrder)) + dimMap[b.getAffineDimExpr(oldDim)] = b.getAffineDimExpr(newDim); + newInputMap = inputMap.replace(dimMap, + /*numResultDims=*/newDimOrder.size(), + /*numResultSymbols=*/0); + newFilterMap = filterMap.replace(dimMap, + /*numResultDims=*/newDimOrder.size(), + 
/*numResultSymbols=*/0); + newOutputMap = outputMap.replace(dimMap, + /*numResultDims=*/newDimOrder.size(), + /*numResultSymbols=*/0); + } + SmallVector iterators = srcConv.getIteratorTypesArray(); + iterators.append(newIteratorTypes); + auto genericConv = b.create( + loc, output.getType(), ValueRange{input, filter}, output, + ArrayRef{newInputMap, newFilterMap, newOutputMap}, iterators); + IRMapping mapper; + srcConv->getRegion(0).cloneInto(&genericConv.getRegion(), mapper); + return genericConv.getResult(0); +} + +template +static Value namedConvBuilderFn( + OpBuilder &b, Location loc, linalg::LinalgOp srcConv, Value input, + Value filter, Value output, AffineMap inputMap, AffineMap filterMap, + AffineMap outputMap, SmallVector newDimOrder, + SmallVector newIteratorTypes) { + sourceNamedConvTy namedConv = cast(srcConv); + return b + .create( + loc, output.getType(), ValueRange{input, filter}, output, + namedConv.getStrides(), namedConv.getDilations()) + .getResult(0); +} + +static TransposeIndices getNormalizedIndices(TransposeIndices targetIndices) { + int startDim = *std::min_element(targetIndices.begin(), targetIndices.end()); + TransposeIndices normalized(targetIndices.size()); + for (auto i : llvm::enumerate(targetIndices)) + normalized[i.index()] = i.value() - startDim; + return normalized; +} + +static TransposeIndices invertIndices(TransposeIndices targetIndices) { + int startDim = *std::min_element(targetIndices.begin(), targetIndices.end()); + TransposeIndices inverted(targetIndices.size()); + for (auto i : llvm::enumerate(targetIndices)) { + inverted[i.value() - startDim] = i.index() + startDim; + } + return inverted; +} + +static bool isInnerIdentityIndices(TransposeIndices indices, int64_t rank) { + return indices.empty() || + (llvm::all_of(llvm::enumerate(indices), + [indices](auto e) { + if (e.index() == 0) return true; + return indices[e.index() - 1] < e.value(); + }) && + indices.back() == rank - 1); +} + +// Helper to shuffle vectors according 
to the transpose indices. +template +static SmallVector shuffleFromIndices(SmallVector unshuffled, + TransposeIndices targetIndices) { + int startDim = *std::min_element(targetIndices.begin(), targetIndices.end()); + SmallVector shuffled(unshuffled); + for (auto i : llvm::enumerate(targetIndices)) { + shuffled[i.index() + startDim] = unshuffled[i.value()]; + } + return shuffled; +} + +template +static SmallVector getPackedVector(SmallVector vec, + TransposeIndices targetIndices) { + SmallVector packedShape; + for (auto [i, val] : llvm::enumerate(vec)) + if (!llvm::is_contained(targetIndices, i)) packedShape.push_back(val); + for (auto i : targetIndices) packedShape.push_back(vec[i]); + return packedShape; +} + +static SmallVector getUntiledPackReassociationMap( + TransposeIndices targetIndices, int64_t rank) { + int startDim = *std::min_element(targetIndices.begin(), targetIndices.end()); + int dimCount = targetIndices.size(); + SmallVector reassociationMap; + for (int i = 0; i <= startDim; i++) reassociationMap.push_back({i}); + for (int i = startDim + 1; i < dimCount + startDim + 1; i++) + reassociationMap[startDim].push_back(i); + for (int i = dimCount + startDim + 1; i < dimCount + rank; i++) + reassociationMap.push_back({i}); + return reassociationMap; +} + +// Transpose the given tensor based on the given transpose indices. Marks the +// created transpose based on the propagation direction. 
+static std::tuple, AffineMap> +createTransposeAsTensorPack( + PatternRewriter &rewriter, Location loc, Value input, AffineMap inputMap, + TransposeIndices targetIndices, int tilingFactor, + llvm::DenseMap innerDimToDomainDim) { + if (isInnerIdentityIndices(targetIndices, inputMap.getNumResults())) + return std::make_tuple(input, std::nullopt, inputMap); + + RankedTensorType inType = input.getType().cast(); + auto elementType = inType.getElementType(); + auto inputShape(inType.getShape()); + + SmallVector transposedTileSizes( + targetIndices.size(), rewriter.getIndexAttr(tilingFactor)); + if (tilingFactor <= 0) { + for (auto [index, i] : llvm::enumerate(targetIndices)) { + if (ShapedType::isDynamic(inputShape[i])) + transposedTileSizes[index] = + rewriter.create(loc, input, i).getResult(); + else + transposedTileSizes[index] = rewriter.getIndexAttr(inputShape[i]); + } + } + + // Pack the input tensor. + auto empty = tensor::PackOp::createDestinationTensor( + rewriter, loc, input, transposedTileSizes, targetIndices, + SmallVector{}); + auto packedInput = rewriter.create( + loc, input, empty, targetIndices, transposedTileSizes, + /*padding=*/std::nullopt, SmallVector{}); + + SmallVector mapResults(inputMap.getResults()); + AffineMap transposedMap; + + Value packedOperand = packedInput; + // Collapse the unit dims created by tensor.pack. 
+ if (tilingFactor <= 0) { + auto reassociationMap = + getUntiledPackReassociationMap(targetIndices, inType.getRank()); + auto transposedInputShape = + getPackedVector(llvm::to_vector(inputShape), targetIndices); + packedOperand = + rewriter + .create( + loc, RankedTensorType::get(transposedInputShape, elementType), + packedOperand, reassociationMap) + .getResult(); + transposedMap = + AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(), + getPackedVector(mapResults, targetIndices), + input.getContext()); + } else { + for (auto innerDim : targetIndices) { + mapResults.push_back(rewriter.getAffineDimExpr( + innerDimToDomainDim[inputMap.getDimPosition(innerDim)])); + } + transposedMap = AffineMap::get( + inputMap.getNumDims() + innerDimToDomainDim.size(), + inputMap.getNumSymbols(), mapResults, input.getContext()); + } + + return std::make_tuple(packedOperand, packedInput, transposedMap); +} + +// Transpose the given tensor based on the given transpose indices. Marks the +// created transpose based on the propagation direction. 
+static Value createTransposeAsTensorUnPack(PatternRewriter &rewriter, + Location loc, Value output, + tensor::PackOp packOp, + int tilingFactor) { + Value packedOutput = output; + if (tilingFactor <= 0) { + RankedTensorType outType = output.getType().cast(); + auto elementType = outType.getElementType(); + auto outputShape(outType.getShape()); + int64_t rank = outType.getRank(); + TransposeIndices targetIndices(packOp.getInnerDimsPos()); + + int startDim = + *std::min_element(targetIndices.begin(), targetIndices.end()); + SmallVector expandedOutputShape; + for (int i = 0, e = startDim; i < e; i++) + expandedOutputShape.push_back(outputShape[i]); + for (int i = 0, e = targetIndices.size(); i < e; i++) + expandedOutputShape.push_back(1); + for (int i = startDim, e = rank; i < e; i++) + expandedOutputShape.push_back(outputShape[i]); + + auto reassociationMap = getUntiledPackReassociationMap(targetIndices, rank); + packedOutput = + rewriter + .create( + loc, RankedTensorType::get(expandedOutputShape, elementType), + output, reassociationMap) + .getResult(); + } + + Value empty = tensor::UnPackOp::createDestinationTensor( + rewriter, loc, packedOutput, packOp.getMixedTiles(), + packOp.getInnerDimsPos(), packOp.getOuterDimsPerm()); + + auto unpackedOutput = rewriter.create( + loc, packedOutput, empty, packOp.getInnerDimsPos(), + packOp.getMixedTiles(), packOp.getOuterDimsPerm()); + unpackedOutput->setAttr("__unpack__", rewriter.getUnitAttr()); + return unpackedOutput.getResult(); +} + +static TransposeIndices collectChannelTransposeIndices( + AffineMap map, SmallVector> transposeDimTargets) { + SmallVector channelIndices(transposeDimTargets.size()); + for (auto [index, result] : llvm::enumerate(map.getResults())) { + if (result.isa()) { + for (auto [channelVec, dimCategory] : + llvm::zip_equal(channelIndices, transposeDimTargets)) { + if (llvm::is_contained(dimCategory, + result.cast().getPosition())) { + channelVec.push_back(index); + break; + } + } + } + } + + 
TransposeIndices indices; + for (auto channelVec : channelIndices) indices.append(channelVec); + return indices; +} + +static LogicalResult transposeConvLikeLinalgOp( + PatternRewriter &rewriter, linalg::LinalgOp convOp, int tilingFactor, + ConvBuilderFn convBuilder = defaultConvBuilderFn) { + Location loc = convOp.getLoc(); + + linalg::ConvolutionDimensions convDims; + auto errString = getMatchConvolutionMessage( + linalg::detail::isConvolutionInterfaceImpl(convOp, &convDims)); + if (!errString.empty()) return failure(); + + if (convDims.inputChannel.size() > 1) return failure(); + + if (convDims.outputChannel.size() > 1) return failure(); + + // TODO: Support depthwise convolutions + if (!convDims.depth.empty()) return failure(); + + Value input = convOp->getOperand(0); + Value filter = convOp->getOperand(1); + Value output = convOp->getOperand(2); + + auto inputMap = convOp.getIndexingMapsArray()[0]; + auto filterMap = convOp.getIndexingMapsArray()[1]; + auto outputMap = convOp.getIndexingMapsArray()[2]; + + auto inputIndices = + collectChannelTransposeIndices(inputMap, {convDims.inputChannel}); + auto filterIndices = collectChannelTransposeIndices( + filterMap, {convDims.inputChannel, convDims.outputChannel}); + auto outputIndices = + collectChannelTransposeIndices(outputMap, {convDims.outputChannel}); + + // Don't transpose if there's no change to the op. 
+ if (isInnerIdentityIndices(inputIndices, inputMap.getNumResults()) && + isInnerIdentityIndices(filterIndices, filterMap.getNumResults()) && + isInnerIdentityIndices(outputIndices, outputMap.getNumResults())) + return failure(); + + int nDims = outputMap.getNumDims(); + llvm::DenseMap innerDimsToDomainDims; + for (auto [index, dim] : llvm::enumerate(convDims.inputChannel)) { + innerDimsToDomainDims[dim] = nDims + index; + } + for (auto [index, dim] : llvm::enumerate(convDims.outputChannel)) { + innerDimsToDomainDims[dim] = nDims + index + convDims.inputChannel.size(); + } + + auto [transposedInput, inputPack, transposedInputMap] = + createTransposeAsTensorPack(rewriter, loc, input, inputMap, inputIndices, + tilingFactor, innerDimsToDomainDims); + auto [transposedFilter, filterPack, transposedFilterMap] = + createTransposeAsTensorPack(rewriter, loc, filter, filterMap, + filterIndices, tilingFactor, + innerDimsToDomainDims); + auto [transposedOutput, outputPack, transposedOutputMap] = + createTransposeAsTensorPack(rewriter, loc, output, outputMap, + outputIndices, tilingFactor, + innerDimsToDomainDims); + + // Don't transpose if there's no change to the op. 
+ if (transposedInputMap == inputMap && transposedFilterMap == filterMap && + transposedOutputMap == outputMap) + return failure(); + + Value convDest = transposedOutput; + if (auto fillOp = output.getDefiningOp()) { + if (outputPack) { + auto outputDest = outputPack->getDest().getDefiningOp(); + auto elementType = outputDest.getType().getElementType(); + + auto dimToTileMapping = outputPack->getDimAndTileMapping(); + SmallVector mixedSizes = outputDest.getMixedSizes(); + SmallVector packedSizes; + for (auto [index, size] : llvm::enumerate(mixedSizes)) + if (!dimToTileMapping.count(index) || tilingFactor > 0) + packedSizes.push_back(size); + + auto emptyOp = + rewriter.create(loc, packedSizes, elementType); + + convDest = rewriter + .create(loc, fillOp.getInputs(), + emptyOp.getResult()) + .result(); + } + } + + SmallVector newDimOrder; + SmallVector newIteratorTypes; + if (tilingFactor <= 0) { + newDimOrder.append(convDims.batch); + newDimOrder.append(convDims.outputImage); + newDimOrder.append(convDims.outputChannel); + newDimOrder.append(convDims.filterLoop); + newDimOrder.append(convDims.inputChannel); + } else { + newIteratorTypes.append(convDims.inputChannel.size(), + utils::IteratorType::reduction); + newIteratorTypes.append(convDims.outputChannel.size(), + utils::IteratorType::parallel); + } + + Value transposedConvResult = + convBuilder(rewriter, loc, convOp, transposedInput, transposedFilter, + convDest, transposedInputMap, transposedFilterMap, + transposedOutputMap, newDimOrder, newIteratorTypes); + + Value returnToNCHW = transposedConvResult; + if (outputPack) { + returnToNCHW = createTransposeAsTensorUnPack( + rewriter, loc, transposedConvResult, *outputPack, tilingFactor); + } + + rewriter.replaceOp(convOp, returnToNCHW); + return success(); +} + +namespace { + +//===================================================================== +// Convolution packing patterns +//===================================================================== + +struct 
ConvertLinalgConvNchwFchw : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + ConvertLinalgConvNchwFchw(MLIRContext *context, PatternBenefit benefit = 2) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp, + PatternRewriter &rewriter) const override { + return transposeConvLikeLinalgOp( + rewriter, convOp, /*tilingFactor=*/-1, + namedConvBuilderFn); + } +}; + +struct ConvertLinalgPoolingNchwMax + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + ConvertLinalgPoolingNchwMax(MLIRContext *context, PatternBenefit benefit = 2) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(linalg::PoolingNchwMaxOp poolOp, + PatternRewriter &rewriter) const override { + return transposeConvLikeLinalgOp( + rewriter, poolOp, /*tilingFactor=*/-1, + namedConvBuilderFn); + } +}; + +struct ConvertLinalgPoolingNchwSum + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + ConvertLinalgPoolingNchwSum(MLIRContext *context, PatternBenefit benefit = 2) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(linalg::PoolingNchwSumOp poolOp, + PatternRewriter &rewriter) const override { + return transposeConvLikeLinalgOp( + rewriter, poolOp, /*tilingFactor=*/-1, + namedConvBuilderFn); + } +}; + +struct ConvertLinalgConvOp : OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + ConvertLinalgConvOp(MLIRContext *context, int tile, + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(context, benefit), + tilingFactor(tile) {} + + LogicalResult matchAndRewrite(linalg::LinalgOp op, + PatternRewriter &rewriter) const override { + if (op->hasAttr(fullTileTransposeMarker)) + return transposeConvLikeLinalgOp(rewriter, op, 0); + return transposeConvLikeLinalgOp(rewriter, op, tilingFactor); + } + + private: + int tilingFactor; +}; + +//===================================================================== +// 
Propagation patterns +//===================================================================== + +class BubbleUpPackThroughPadOp final : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + auto padOp = packOp.getSource().getDefiningOp(); + if (!padOp) return failure(); + + if (!padOp.getResult().hasOneUse()) return failure(); + + // TODO: Enable padding. + if (packOp.getPaddingValue()) return failure(); + + // TODO: Enable outer dims perm. + if (!packOp.getOuterDimsPerm().empty()) return failure(); + + // We want to move the pack not the insert_slice. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(padOp); + + Location loc = padOp->getLoc(); + auto mixedTiles = packOp.getMixedTiles(); + auto innerDimsPos = packOp.getInnerDimsPos(); + auto outerDimsPerm = packOp.getOuterDimsPerm(); + if (!packOp.getDest().getDefiningOp()) return failure(); + + // Bail out if one of the padded dimension is a tiled one. + llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); + llvm::SmallBitVector innerDims(paddedDims.size()); + for (int64_t dim : innerDimsPos) innerDims.flip(dim); + if (paddedDims.anyCommon(innerDims)) return failure(); + + Value paddingVal = padOp.getConstantPaddingValue(); + if (!paddingVal) return failure(); + + auto empty = tensor::PackOp::createDestinationTensor( + rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos, + outerDimsPerm); + Value packedSource = rewriter.create( + loc, padOp.getSource(), empty, innerDimsPos, mixedTiles, + /*padding=*/std::nullopt, outerDimsPerm); + + // If we have `outer_dims_perms` we need to adjust the padded dimensions. 
+ SmallVector lowPad = padOp.getMixedLowPad(); + SmallVector highPad = padOp.getMixedHighPad(); + if (!outerDimsPerm.empty()) { + applyPermutationToVector(lowPad, outerDimsPerm); + applyPermutationToVector(highPad, outerDimsPerm); + } + // Add zero padding for the point loops. + size_t pointLoopsSize = innerDimsPos.size(); + lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); + highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); + + auto newPadOp = rewriter.create( + loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal, + padOp.getNofold()); + rewriter.replaceOp(packOp, newPadOp.getResult()); + return success(); + } +}; + +class BubbleUpPackThroughTensorInsertSlice final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + auto insertSliceOp = + packOp.getSource().getDefiningOp(); + if (!insertSliceOp) return failure(); + + if (!insertSliceOp.getResult().hasOneUse()) return failure(); + + // TODO: Enable rank reduced slice. + if (insertSliceOp.getSourceType().getRank() != + insertSliceOp.getDestType().getRank()) + return failure(); + + // TODO: Enable padding. + if (packOp.getPaddingValue()) return failure(); + + // TODO: Enable outer dims perm. + if (!packOp.getOuterDimsPerm().empty()) return failure(); + + // We want to move the pack not the insert_slice. 
+ OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(insertSliceOp); + + Location loc = insertSliceOp->getLoc(); + auto mixedTiles = packOp.getMixedTiles(); + auto innerDimsPos = packOp.getInnerDimsPos(); + auto outerDimsPerm = packOp.getOuterDimsPerm(); + Value packOpDest = packOp.getDest(); + if (!packOpDest.hasOneUse()) return failure(); + if (auto emptyOp = packOpDest.getDefiningOp()) { + packOpDest = tensor::PackOp::createDestinationTensor( + rewriter, loc, insertSliceOp.getDest(), mixedTiles, innerDimsPos, + outerDimsPerm); + } else { + DominanceInfo dom(insertSliceOp); + if (!dom.properlyDominates(packOpDest, insertSliceOp)) return failure(); + } + + SmallVector mixedSliceTiles(packOp.getMixedTiles()); + + SmallVector mixedOffsets(insertSliceOp.getMixedOffsets()); + SmallVector mixedSizes(insertSliceOp.getMixedSizes()); + SmallVector mixedStrides(insertSliceOp.getMixedStrides()); + + for (auto [index, dimPos, mixedTileSize] : + llvm::zip_equal(llvm::seq(0, innerDimsPos.size()), + innerDimsPos, mixedTiles)) { + if (!getConstantIntValue(mixedStrides[dimPos])) return failure(); + + std::optional constTileSize = getConstantIntValue(mixedTileSize); + if (!constTileSize) return failure(); + + std::optional constOffset = + getConstantIntValue(mixedOffsets[dimPos]); + if (!constOffset) return failure(); + + std::optional constSize = + getConstantIntValue(mixedSizes[dimPos]); + if (!constSize) return failure(); + + int64_t tileSize = *constTileSize; + int64_t offset = *constOffset; + int64_t size = *constSize; + + if ((size % tileSize != 0 || offset % tileSize != 0) && + (offset / tileSize > (size + offset) / tileSize)) + return failure(); + mixedSliceTiles[index] = + rewriter.getI64IntegerAttr(std::min(size, tileSize)); + mixedOffsets[dimPos] = rewriter.getI64IntegerAttr(offset / tileSize); + mixedSizes[dimPos] = + rewriter.getI64IntegerAttr(std::max(size / tileSize, 1)); + + mixedOffsets.push_back(rewriter.getI64IntegerAttr(offset %
tileSize)); + mixedSizes.push_back( + rewriter.getI64IntegerAttr(std::min(size, tileSize))); + mixedStrides.push_back(rewriter.getI64IntegerAttr(1)); + } + + Value newDest = packOpDest; + if (!insertSliceOp.getDest().getDefiningOp()) { + newDest = rewriter.create( + loc, insertSliceOp.getDest(), packOpDest, innerDimsPos, mixedTiles, + /*padding=*/std::nullopt, outerDimsPerm); + } + + auto empty = tensor::PackOp::createDestinationTensor( + rewriter, loc, insertSliceOp.getSource(), mixedSliceTiles, innerDimsPos, + outerDimsPerm); + Value packedSlice = rewriter.create( + loc, insertSliceOp.getSource(), empty, innerDimsPos, mixedSliceTiles, + /*padding=*/std::nullopt, outerDimsPerm); + + rewriter.replaceOpWithNewOp( + packOp, packedSlice, newDest, mixedOffsets, mixedSizes, mixedStrides); + return success(); + } +}; + +//===================================================================== +// Generalization and folding patterns +//===================================================================== + +template +class GeneralizeUntiledPackOrUnPackOp final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(PackOrUnPackOpTy op, + PatternRewriter &rewriter) const override { + if (!op.getMixedTiles().empty()) return failure(); + TransposeIndices perm(op.getOuterDimsPerm()); + if (std::is_same::value) + perm = invertIndices(perm); + rewriter.replaceOpWithNewOp(op, op.getSource(), + op.getDest(), perm); + return success(); + } +}; + +static SmallVector getTilingReassociationMap( + int64_t rank, llvm::DenseMap innerDims) { + SmallVector map; + int64_t nTiled = 0; + for (int64_t i = 0, e = rank; i < e; i++) { + if (innerDims.count(i)) { + map.push_back({i + nTiled++, i + nTiled}); + continue; + } + map.push_back({i + nTiled}); + } + return map; +} + +class GeneralizeUnPermutedPackOp final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult 
matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + if (!packOp.getOuterDimsPerm().empty()) return failure(); + if (packOp.getPaddingValue()) return failure(); + + RankedTensorType srcType = + packOp.getSource().getType().cast(); + int64_t rank = srcType.getRank(); + auto innerDimsPos = packOp.getInnerDimsPos(); + llvm::DenseMap innerDims; + for (auto [index, innerDim] : llvm::enumerate(innerDimsPos)) + innerDims[innerDim] = index; + + llvm::DenseMap innerDimsToExpandedDims; + TransposeIndices perm; + int64_t nTiled = 0; + for (int i = 0, e = rank; i < e; i++) { + perm.push_back(i + nTiled); + if (innerDims.count(i)) innerDimsToExpandedDims[i] = i + ++nTiled; + } + for (auto i : innerDimsPos) perm.push_back(innerDimsToExpandedDims[i]); + + RankedTensorType destType = + packOp.getDest().getType().cast(); + SmallVector destShape(destType.getShape()); + applyPermutationToVector(destShape, invertPermutationVector(perm)); + + auto expand = rewriter.create( + packOp.getLoc(), + RankedTensorType::get(destShape, destType.getElementType()), + packOp.getSource(), getTilingReassociationMap(rank, innerDims)); + + rewriter.replaceOpWithNewOp(packOp, expand, + packOp.getDest(), perm); + return success(); + } +}; + +class GeneralizeUnPermutedUnPackOp final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp, + PatternRewriter &rewriter) const override { + if (!unpackOp.getOuterDimsPerm().empty()) return failure(); + + if (!unpackOp.getDest().getDefiningOp()) return failure(); + + RankedTensorType destType = + unpackOp.getDest().getType().cast(); + int64_t rank = destType.getRank(); + auto innerDimsPos = unpackOp.getInnerDimsPos(); + llvm::DenseMap innerDims; + for (auto [index, innerDim] : llvm::enumerate(innerDimsPos)) + innerDims[innerDim] = index; + + TransposeIndices perm; + for (int i = 0, e = rank; i < e; i++) { + perm.push_back(i); + if 
(innerDims.count(i)) perm.push_back(rank + innerDims[i]); + } + + Location loc = unpackOp.getLoc(); + SmallVector mixedSizes = + tensor::getMixedSizes(rewriter, loc, unpackOp.getSource()); + applyPermutationToVector(mixedSizes, perm); + auto elType = getElementTypeOrSelf(unpackOp.getDest()); + + auto emptyOp = rewriter.create(loc, mixedSizes, elType); + + Value transpose = rewriter + .create( + loc, unpackOp.getSource(), emptyOp, perm) + ->getResult(0); + + rewriter.replaceOpWithNewOp( + unpackOp, destType, transpose, + getTilingReassociationMap(rank, innerDims)); + return success(); + } +}; + +class GeneralizeLinalgTransposeOp final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::TransposeOp op, + PatternRewriter &rewriter) const override { + auto linalgOp = cast(*op); + auto transpose = + rewriter + .create( + op.getLoc(), op.getResult().getType(), op.getInput(), + op.getInit(), linalgOp.getIndexingMapsArray(), + linalgOp.getIteratorTypesArray(), + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }) + .getResult(0); + rewriter.replaceOp(op, transpose); + return success(); + } +}; + +class FoldCancellingUnPackPackOps final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp, + PatternRewriter &rewriter) const override { + return tensor::UnPackOp::canonicalize(unpackOp, rewriter); + } +}; + +class FoldCancellingPackUnPackOps final + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + return tensor::PackOp::canonicalize(packOp, rewriter); + } +}; + +struct ConvertConvToChannelsLastPass + : public ConvertConvToChannelsLastBase { + public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + 
registry.insert(); + } + LogicalResult initializeOptions(StringRef options) override { + if (failed(Pass::initializeOptions(options))) { + return failure(); + } + tilingFactor = tileSize; + return success(); + } + + void runOnOperation() override { + auto op = getOperation(); + MLIRContext *context = &getContext(); + + { + RewritePatternSet patterns(context); + if (tilingFactor < 0) { + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + } + patterns.insert(context, tilingFactor); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + { + RewritePatternSet patterns(context); + linalg::populateDataLayoutPropagationPatterns( + patterns, [](Operation *op) { return true; }); + patterns.insert(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + { + RewritePatternSet patterns(context); + patterns.insert(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + { + RewritePatternSet patterns(context); + patterns.add(context); + patterns.insert(context); + patterns.insert>(context); + patterns.insert>( + context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + + { + RewritePatternSet patterns(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) { + return signalPassFailure(); + } + } + } + + private: + int64_t tilingFactor; +}; + +} // namespace + +std::unique_ptr createConvertConvToChannelsLastPass() { + return std::make_unique(); +} + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h 
b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h index f7ea5661464cf..a339fb37d5a2c 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h @@ -39,6 +39,9 @@ createMakeSingleDispatchForFunctionPass(); // A pass to generalize all conv-like ops. std::unique_ptr createGeneralizeConvolutionsPass(); +// Creates a pass to convert convolutions to channels last and propagate. +std::unique_ptr createConvertConvToChannelsLastPass(); + // A pass to pad linalg ops to the next integer multiple of `paddingSize`. std::unique_ptr createPadLinalgOpsToIntegerMultiplePass(); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index 04e615e763333..2948aaaab929c 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -46,6 +46,18 @@ def GeneralizeConvolutions : let constructor = "mlir::iree_compiler::IREE::createGeneralizeConvolutionsPass()"; } +def ConvertConvToChannelsLast : + Pass<"iree-preprocessing-convert-conv-to-channels-last", ""> { + let summary = "Convert linalg convolutions to channels last."; + let constructor = + "mlir::iree_compiler::IREE::createConvertConvToChannelsLastPass()"; + let options = [ + Option<"tileSize", "tile-size", "int", + /*default=*/"0", + "Specify the tiling factor">, + ]; +} + def PadLinalgOps : Pass<"iree-preprocessing-pad-linalg-ops", ""> { let summary = "Pad linalg ops to the next integer multiple of paddingSize."; From fcfb1d6f318c5e5449103358016b2003679d0433 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Wed, 10 May 2023 04:46:08 -0400 Subject: [PATCH 15/44] Fix embedded linker flag in python exposed iree-opt --- .../python/iree/compiler/tools/core.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/compiler/bindings/python/iree/compiler/tools/core.py 
b/compiler/bindings/python/iree/compiler/tools/core.py index 84bc978d44a42..91fafb31259aa 100644 --- a/compiler/bindings/python/iree/compiler/tools/core.py +++ b/compiler/bindings/python/iree/compiler/tools/core.py @@ -324,6 +324,43 @@ def query_available_targets(): # Preprocessing for SHARK (for now simply exposes iree-opt) +def build_opt_command_line( + input_file: str, tfs: TempFileSaver, options: CompilerOptions +) -> List[str]: + """Builds a command line for applying specified patterns. + + Args: + input_file: The input file name. + tfs: TempFileSaver. + options: Compiler options. + Returns: + List of strings of command line. + """ + iree_opt = find_tool("iree-opt") + cl = [ + iree_opt, + input_file, + ] + + # Output file. + if options.output_file: + cl.append(f"-o={options.output_file}") + + # Tool paths. + lld_path = find_tool("iree-lld") + cl.append(f"--iree-llvmcpu-embedded-linker-path={lld_path}") + + crash_reproducer_path = tfs.alloc_optional( + "core-reproducer.mlir", export_as=options.crash_reproducer_path + ) + if crash_reproducer_path: + cl.append(f"--mlir-pass-pipeline-crash-reproducer={crash_reproducer_path}") + + cl.extend(options.extra_args) + print(cl) + return cl + + def build_opt_command_line( input_file: str, tfs: TempFileSaver, options: CompilerOptions ) -> List[str]: From 9a16f595be03bb082bbda00f0138e69bb69f7427 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 20 Jul 2023 11:51:35 -0400 Subject: [PATCH 16/44] Add pattern for bubbling vector.bitcast through an enclosing scf.if <32 bit width types are handled on the SPIR-V side by introducing bitcasts to and from i32 and bubbling them to the center of the kernel hoping to cancel. This adds a pattern for a bitcast on the result of an scf.if, which comes from the way that padding is handled (transfer_read in the `then` branch, else yield a splat constant). 
--- .../Common/OptimizeVectorTransferPass.cpp | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp index 66acf8cb670ac..f26f3223f3ced 100644 --- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp @@ -46,6 +46,84 @@ class TransposeUnitDimToShapeCast } }; +// TODO: Move this upstream +// Hoists a vector.bitcast op to the output of the enclosing scf.if +// +// This transforms IR like: +// %0 = scf.if %1 -> (vector<16xi8>) { +// %2 = memref.load %4[%c0] : memref<?xvector<4xi32>> +// %3 = vector.bitcast %2 : vector<4xi32> to vector<16xi8> +// scf.yield %3 : vector<16xi8> +// } else { +// scf.yield %cst : vector<16xi8> +// } +// Into: +// %0 = scf.if %1 -> (vector<4xi32>) { +// %2 = memref.load %4[%c0] : memref<?xvector<4xi32>> +// scf.yield %2 : vector<4xi32> +// } else { +// %3 = vector.bitcast %cst : vector<16xi8> to vector<4xi32> +// scf.yield %3 : vector<4xi32> +// } +// %3 = vector.bitcast %0 : vector<4xi32> to vector<16xi8> +struct BubbleUpBitCastOfScfIf : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::IfOp ifOp, + PatternRewriter &rewriter) const override { + // Bail on more than one result for now. + scf::YieldOp thenYield = ifOp.thenYield(); + if (!thenYield || thenYield.getNumOperands() != 1) + return failure(); + auto bitcastOp = thenYield.getOperand(0).getDefiningOp(); + // Bail out if no bitcast on the if then statement. + if (!bitcastOp) + return failure(); + + VectorType castSrcType = bitcastOp.getSourceVectorType(); + VectorType castDstType = bitcastOp.getResultVectorType(); + assert(castSrcType.getRank() == castDstType.getRank()); + // Skip 0-D vector.
+ if (castSrcType.getRank() == 0) + return failure(); + + int64_t castSrcLastDim = castSrcType.getShape().back(); + int64_t castDstLastDim = castDstType.getShape().back(); + // Require casting to more elements; + if (castSrcLastDim > castDstLastDim) + return failure(); + + Location loc = ifOp.getLoc(); + + auto bitcastedIfOp = + rewriter.create(loc, castSrcType, ifOp.getCondition()); + bitcastedIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + bitcastedIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + + scf::YieldOp newThenYield = bitcastedIfOp.thenYield(); + auto newBitcastOp = + newThenYield.getOperand(0).getDefiningOp(); + + newThenYield.setOperand(0, newBitcastOp.getSource()); + + auto newBitcast = rewriter.create( + loc, castDstType, bitcastedIfOp.getResult(0)); + + scf::YieldOp elseYield = bitcastedIfOp.elseYield(); + if (elseYield) { + OpBuilder::InsertionGuard elseGuard(rewriter); + rewriter.setInsertionPoint(elseYield); + + Value yieldSrc = elseYield.getOperand(0); + auto elseBitcast = + rewriter.create(loc, castSrcType, yieldSrc); + elseYield.setOperand(0, elseBitcast); + } + rewriter.replaceOp(ifOp, newBitcast); + return success(); + } +}; + static void loopInvariantCodeMotion(func::FuncOp funcOp) { // Walk through all loops in a function in innermost-loop-first order. 
This // way, we first LICM from the inner loop, and place the ops in @@ -89,6 +167,7 @@ struct OptimizeVectorTransferPass { RewritePatternSet patterns(&getContext()); vector::populateBubbleVectorBitCastOpPatterns(patterns); + patterns.add(&getContext()); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } From fa4b8de2c6feb433eec0d4fdb2366cd3db68362d Mon Sep 17 00:00:00 2001 From: Anush Elangovan Date: Fri, 5 Aug 2022 06:39:45 +0000 Subject: [PATCH 17/44] [CI] Add ROCM builds to the nightly Build Experimental ROCM builds --- compiler/setup.py | 2 ++ runtime/setup.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/compiler/setup.py b/compiler/setup.py index 87afbc5586bb2..b49001b94135c 100644 --- a/compiler/setup.py +++ b/compiler/setup.py @@ -259,6 +259,8 @@ def prepare_installation(): "-DPython3_EXECUTABLE={}".format(sys.executable), "-DCMAKE_BUILD_TYPE={}".format(cfg), # TODO(scotttodd): include IREE_TARGET_BACKEND_WEBGPU here (and in env) + get_env_cmake_option("IREE_TARGET_BACKEND_ROCM"), + get_env_cmake_option("IREE_TARGET_BACKEND_OPENCL_SPIRV"), get_env_cmake_option("IREE_ENABLE_CPUINFO", "ON"), ] cmake_args.extend(get_cmake_version_info_args()) diff --git a/runtime/setup.py b/runtime/setup.py index a6185f84afad8..238e72cc2f83b 100644 --- a/runtime/setup.py +++ b/runtime/setup.py @@ -274,7 +274,8 @@ def build_configuration(cmake_build_dir, cmake_install_dir, extra_cmake_args=()) "IREE_HAL_DRIVER_VULKAN", "OFF" if platform.system() == "Darwin" else "ON", ), - get_env_cmake_list("IREE_EXTERNAL_HAL_DRIVERS", ""), + get_env_cmake_list("IREE_EXTERNAL_HAL_DRIVERS", + "" if platform.system() != "Linux" else "rocm"), get_env_cmake_option("IREE_ENABLE_CPUINFO", "ON"), ] + list(extra_cmake_args) add_env_cmake_setting(cmake_args, "IREE_TRACING_PROVIDER") From d612af13c3d492fd516da738b2334e49a6612753 Mon Sep 17 00:00:00 2001 From: powderluv Date: Sat, 8 Jul 2023 14:07:19 -0700 Subject: [PATCH
18/44] [BUILD] - Remove documentation build before publishing website --- .github/workflows/publish_website.yml | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/.github/workflows/publish_website.yml b/.github/workflows/publish_website.yml index 28803fb23fbf9..e41b94627abe7 100644 --- a/.github/workflows/publish_website.yml +++ b/.github/workflows/publish_website.yml @@ -44,13 +44,6 @@ jobs: with: python-version: 3.x cache: 'pip' - - id: "gcp-auth" - name: "Authenticating to Google Cloud" - uses: "google-github-actions/auth@v1" - with: - token_format: "access_token" - credentials_json: "${{ secrets.IREE_OSS_GITHUB_RUNNER_BASIC_TRUST_SERVICE_ACCOUNT_KEY }}" - create_credentials_file: false - name: Installing dependencies run: | pip install -r docs/website/requirements.txt @@ -60,14 +53,6 @@ jobs: ./build_tools/scripts/generate_release_index.py \ --repo="${GITHUB_REPOSITORY}" \ --output=docs/website/docs/pip-release-links.html - - name: Building documentation files - run: | - ./build_tools/github_actions/docker_run.sh \ - --env "IREE_CCACHE_GCP_TOKEN=${{ steps.gcp-auth.outputs.access_token }}" \ - --env "IREE_WRITE_REMOTE_CCACHE=1" \ - --env "CCACHE_NAMESPACE=gcr.io/iree-oss/base@sha256:796fb81a11ff7e7d057c93de468b74e48b6a9641aa19b7f7673c2772e8ea3b33" \ - gcr.io/iree-oss/base@sha256:796fb81a11ff7e7d057c93de468b74e48b6a9641aa19b7f7673c2772e8ea3b33 \ - ./docs/website/generate_extra_files.sh - name: Setting git config run: | git config --local user.email "iree-github-actions-bot@google.com" From 98a5a2d0eaf1105025633d677f7021221d12416b Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 20 Jul 2023 12:43:05 -0400 Subject: [PATCH 19/44] Drop CODEOWNERS to prevent sending review requests for SHARK-Runtime --- .github/CODEOWNERS | 81 ---------------------------------------------- 1 file changed, 81 deletions(-) delete mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index 
bd825f42c62bd..0000000000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1,81 +0,0 @@ -# Codeowners for IREE Github Repository. -# The listed owners will automatically be added as reviewers to PRs that modify -# paths matching the specified patterns. -# Refer to https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners -# for syntax of this file (tl;dr: syntax is like .gitignore. Last matching rule -# takes precedence). -# Because of the precedence, rules for directories are listed topologically. -# @ghost is used to make a pattern have no owners. It is a sentinel GitHub user -# that takes the place of deleted users. - -# No global owners because we don't really want e.g. changing the root -# CMakeLists.txt file to always ping a bunch of people. - -# Third-Party Code -/.gitmodules @ScottTodd @stellaraccident -/third_party/ @ScottTodd @stellaraccident -# Except for routinely-updated submodules -/third_party/llvm-project @ghost -/third_party/llvm-project.branch-pin @ghost -/third_party/mlir-hlo @ghost - -# Bindings -/runtime/bindings/python/ @stellaraccident -/runtime/bindings/tflite/ @benvanik - -# Integrations -/integrations/ @benvanik @stellaraccident -/integrations/tensorflow/ @stellaraccident -/integrations/tensorflow/test/**/iree_tfl_tests/ @rsuderman - -# Experimental -# It's experimental, but we still don't want any old directory added here. 
-/experimental/ @benvanik @stellaraccident -/experimental/cpu_ukernel/ @bjacob -/experimental/cuda2/ @antiagainst -/experimental/dispatch_profiler/ @manishucsd -/experimental/rocm/ @benvanik -/experimental/web/ @ScottTodd -/experimental/webgpu/ @benvanik @ScottTodd - -# Infra Top-Level Directories -/build_tools/ @ScottTodd @pzread -/build_tools/benchmarks/ @antiagainst @pzread -/build_tools/python/ @pzread -/build_tools/python_deploy/ @stellaraccident -/build_tools/scripts/ @ScottTodd -/build_tools/third_party/ @ScottTodd @stellaraccident -/.github/ @ScottTodd - -# llvm-external-projects -/llvm-external-projects/ @stellaraccident -/llvm-external-projects/iree-dialects/ @MaheshRavishankar -/llvm-external-projects/iree-dialects/**/Dialect/LinalgExt/ @hanhanW @MaheshRavishankar -/llvm-external-projects/iree-dialects/test/iree_linalgext @hanhanW @MaheshRavishankar - -# Other Top-Level Directories -/docs/ @ScottTodd -/samples/ @ScottTodd -/tools/ @benvanik - -# Compiler -/compiler/src/iree/compiler/ @benvanik -/compiler/src/iree/compiler/Codegen/ @MaheshRavishankar -/compiler/src/iree/compiler/Codegen/Common @hanhanW @dcaballe -/compiler/src/iree/compiler/Codegen/LLVMCPU/ @dcaballe @hanhanW @MaheshRavishankar -/compiler/src/iree/compiler/Codegen/LLVMGPU/ @MaheshRavishankar -/compiler/src/iree/compiler/Codegen/SPIRV/ @antiagainst @MaheshRavishankar -/compiler/src/iree/compiler/Codegen/TransformStrategies/ @qedawkins @MaheshRavishankar -/compiler/src/iree/compiler/ConstEval/ @hanhanW @stellaraccident -/compiler/src/iree/compiler/Dialect/Flow/ @hanhanW @MaheshRavishankar -/compiler/src/iree/compiler/Dialect/Vulkan/ @antiagainst -/compiler/src/iree/compiler/GlobalOptimization/ @hanhanW -/compiler/src/iree/compiler/InputConversion/ @MaheshRavishankar @stellaraccident -/compiler/src/iree/compiler/InputConversion/MHLO @hanhanW @MaheshRavishankar @rsuderman -/compiler/src/iree/compiler/InputConversion/TOSA @MaheshRavishankar @rsuderman - -# Runtime -/runtime/src/iree/ @benvanik 
-/runtime/src/iree/hal/cts/ @ScottTodd -/runtime/src/iree/hal/drivers/metal/ @antiagainst -/runtime/src/iree/hal/drivers/vulkan/ @antiagainst @ScottTodd From 1e9a63dc2d1cc44618b66b784f4af02e0021ebb3 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Fri, 1 Jul 2022 15:18:44 -0700 Subject: [PATCH 20/44] [LevelZero] Add LevelZero to HAL and Codegen Use flags -DIREE_BUILD_EXPERIMENTAL_LEVEL_ZERO=ON -DLEVEL_ZERO_HEADERS_API_ROOT=/home/stanley/nod/level-zero -LevelZero HAL Driver -Add OpenCL HAL Target compiler -Add SPIRV Codegen for Kernel capability -fix illegal pointer arithmetic on void* by Boian -Use events for command buffer execution and synchronization (#47) -Add flag for switching between Physical64 and Physical32 addressing in OpenCL -enable creation of device by UUID + ID(device handle as uintptr_t) Note that the device ID implemented like that is ephemeral and is valid only in the current IREE runtime context. If you start a new process the IDs will be different. With this change you can do $ iree-run-module --list_devices level_zero://00005100-0000-0000-0000-000000000001 $ iree-run-module --device=level_zero://00005100-0000-0000-0000-000000000001 ... -add query_memory_heaps implementation Fixes error: arithmetic on a pointer to void is a GNU extension [-Werror,-Wgnu-pointer-arith] -Supply structure type when passing such arguments in accordance with the API (#34) When calling Level Zero API functions that query information and use a struct to populate the information, the user must supply the structure type (stype) in the structure itself. The struct is both in and out argument. Fix Level Zero build Remove usage of iree_hal_command_buffer_dyn_cast.
--- .gitmodules | 3 + CMakeLists.txt | 10 + build_tools/cmake/iree_check_test.cmake | 4 +- build_tools/cmake/iree_llvm.cmake | 3 + .../Codegen/Common/GPU/WorkGroupSwizzle.cpp | 1 + .../Codegen/SPIRV/ConvertToSPIRVPass.cpp | 303 ++++++++++- .../compiler/Codegen/SPIRV/KernelConfig.cpp | 4 + .../iree/compiler/Codegen/SPIRV/Passes.cpp | 15 +- .../src/iree/compiler/Codegen/SPIRV/Passes.h | 9 +- .../Target/MetalSPIRV/MetalSPIRVTarget.cpp | 4 +- .../Dialect/HAL/Target/OpenCLSPIRV/BUILD | 51 ++ .../HAL/Target/OpenCLSPIRV/CMakeLists.txt | 43 ++ .../Target/OpenCLSPIRV/OpenCLSPIRVTarget.cpp | 272 ++++++++++ .../Target/OpenCLSPIRV/OpenCLSPIRVTarget.h | 38 ++ .../Dialect/HAL/Target/OpenCLSPIRV/test/BUILD | 30 ++ .../Target/OpenCLSPIRV/test/CMakeLists.txt | 24 + .../HAL/Target/OpenCLSPIRV/test/linking.mlir | 142 +++++ .../Target/OpenCLSPIRV/test/smoketest.mlir | 39 ++ .../Target/VulkanSPIRV/VulkanSPIRVTarget.cpp | 4 +- .../HAL/Target/WebGPU/WebGPUTarget.cpp | 4 +- .../src/iree/compiler/Tools/CMakeLists.txt | 4 + .../src/iree/compiler/Tools/init_targets.cc | 7 + experimental/level_zero/CMakeLists.txt | 105 ++++ experimental/level_zero/api.h | 47 ++ experimental/level_zero/context_wrapper.h | 22 + experimental/level_zero/cts/CMakeLists.txt | 29 ++ .../level_zero/direct_command_buffer.c | 488 ++++++++++++++++++ .../level_zero/direct_command_buffer.h | 56 ++ .../level_zero/dynamic_symbol_tables.h | 90 ++++ experimental/level_zero/dynamic_symbols.c | 65 +++ experimental/level_zero/dynamic_symbols.h | 48 ++ .../level_zero/dynamic_symbols_test.cc | 85 +++ experimental/level_zero/event_semaphore.c | 114 ++++ experimental/level_zero/event_semaphore.h | 30 ++ .../level_zero/level_zero_allocator.c | 349 +++++++++++++ .../level_zero/level_zero_allocator.h | 29 ++ experimental/level_zero/level_zero_buffer.c | 140 +++++ experimental/level_zero/level_zero_buffer.h | 42 ++ experimental/level_zero/level_zero_device.c | 427 +++++++++++++++ experimental/level_zero/level_zero_device.h | 31 ++ 
experimental/level_zero/level_zero_driver.c | 370 +++++++++++++ experimental/level_zero/level_zero_event.c | 78 +++ experimental/level_zero/level_zero_event.h | 36 ++ experimental/level_zero/level_zero_headers.h | 17 + experimental/level_zero/native_executable.c | 186 +++++++ experimental/level_zero/native_executable.h | 45 ++ .../level_zero/nop_executable_cache.c | 96 ++++ .../level_zero/nop_executable_cache.h | 30 ++ experimental/level_zero/pipeline_layout.c | 211 ++++++++ experimental/level_zero/pipeline_layout.h | 63 +++ .../level_zero/registration/CMakeLists.txt | 21 + .../level_zero/registration/driver_module.c | 55 ++ .../level_zero/registration/driver_module.h | 24 + experimental/level_zero/status_util.c | 242 +++++++++ experimental/level_zero/status_util.h | 63 +++ experimental/level_zero/test/CMakeLists.txt | 19 + .../level_zero/test/level_zero_test.cc | 86 +++ runtime/src/iree/schemas/CMakeLists.txt | 12 + .../schemas/level_zero_executable_def.fbs | 33 ++ tests/e2e/stablehlo_ops/CMakeLists.txt | 69 +++ tests/e2e/tosa_ops/CMakeLists.txt | 50 ++ third_party/level-zero | 1 + 62 files changed, 4989 insertions(+), 29 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/BUILD create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/OpenCLSPIRVTarget.cpp create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/OpenCLSPIRVTarget.h create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/BUILD create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/linking.mlir create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/smoketest.mlir create mode 100644 experimental/level_zero/CMakeLists.txt create mode 100644 
experimental/level_zero/api.h create mode 100644 experimental/level_zero/context_wrapper.h create mode 100644 experimental/level_zero/cts/CMakeLists.txt create mode 100644 experimental/level_zero/direct_command_buffer.c create mode 100644 experimental/level_zero/direct_command_buffer.h create mode 100644 experimental/level_zero/dynamic_symbol_tables.h create mode 100644 experimental/level_zero/dynamic_symbols.c create mode 100644 experimental/level_zero/dynamic_symbols.h create mode 100644 experimental/level_zero/dynamic_symbols_test.cc create mode 100644 experimental/level_zero/event_semaphore.c create mode 100644 experimental/level_zero/event_semaphore.h create mode 100644 experimental/level_zero/level_zero_allocator.c create mode 100644 experimental/level_zero/level_zero_allocator.h create mode 100644 experimental/level_zero/level_zero_buffer.c create mode 100644 experimental/level_zero/level_zero_buffer.h create mode 100644 experimental/level_zero/level_zero_device.c create mode 100644 experimental/level_zero/level_zero_device.h create mode 100644 experimental/level_zero/level_zero_driver.c create mode 100644 experimental/level_zero/level_zero_event.c create mode 100644 experimental/level_zero/level_zero_event.h create mode 100644 experimental/level_zero/level_zero_headers.h create mode 100644 experimental/level_zero/native_executable.c create mode 100644 experimental/level_zero/native_executable.h create mode 100644 experimental/level_zero/nop_executable_cache.c create mode 100644 experimental/level_zero/nop_executable_cache.h create mode 100644 experimental/level_zero/pipeline_layout.c create mode 100644 experimental/level_zero/pipeline_layout.h create mode 100644 experimental/level_zero/registration/CMakeLists.txt create mode 100644 experimental/level_zero/registration/driver_module.c create mode 100644 experimental/level_zero/registration/driver_module.h create mode 100644 experimental/level_zero/status_util.c create mode 100644 
experimental/level_zero/status_util.h create mode 100644 experimental/level_zero/test/CMakeLists.txt create mode 100644 experimental/level_zero/test/level_zero_test.cc create mode 100644 runtime/src/iree/schemas/level_zero_executable_def.fbs create mode 160000 third_party/level-zero diff --git a/.gitmodules b/.gitmodules index 1e0db2c57a286..35925e5c453dc 100644 --- a/.gitmodules +++ b/.gitmodules @@ -44,3 +44,6 @@ [submodule "third_party/torch-mlir"] path = third_party/torch-mlir url = https://github.com/shark-infra/torch-mlir.git +[submodule "third_party/level-zero"] + path = third_party/level-zero + url = https://github.com/oneapi-src/level-zero diff --git a/CMakeLists.txt b/CMakeLists.txt index b5fafcf17e2a1..ff2ac3b252503 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -351,6 +351,15 @@ set(IREE_EXTERNAL_ROCM_HAL_DRIVER_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/experi set(IREE_EXTERNAL_ROCM_HAL_DRIVER_TARGET "iree::experimental::rocm::registration") set(IREE_EXTERNAL_ROCM_HAL_DRIVER_REGISTER "iree_hal_rocm_driver_module_register") +#------------------------------------------------------------------------------- +# Experimental Intel Level Zero HAL driver +#------------------------------------------------------------------------------- + +set(IREE_EXTERNAL_LEVEL_ZERO_HAL_DRIVER_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/experimental/level_zero") +set(IREE_EXTERNAL_LEVEL_ZERO_HAL_DRIVER_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/experimental/level_zero") +set(IREE_EXTERNAL_LEVEL_ZERO_HAL_DRIVER_TARGET "iree::experimental::level_zero::registration") +set(IREE_EXTERNAL_LEVEL_ZERO_HAL_DRIVER_REGISTER "iree_hal_level_zero_driver_module_register") + #------------------------------------------------------------------------------- # Experimental WebGPU HAL driver #------------------------------------------------------------------------------- @@ -390,6 +399,7 @@ cmake_dependent_option(IREE_TARGET_BACKEND_CUDA "Enables the 'cuda' compiler tar # Non-default target backends 
either have additional dependencies or are # experimental/niche in some fashion. cmake_dependent_option(IREE_TARGET_BACKEND_ROCM "Enables the 'rocm' compiler target backend" OFF ${IREE_BUILD_COMPILER} OFF) +cmake_dependent_option(IREE_TARGET_BACKEND_OPENCL_SPIRV "Enables the 'OpenCL-SPIRV' compiler target backend" OFF ${IREE_BUILD_COMPILER} OFF) # Disable WebGPU by default - it has complex deps and is under development. cmake_dependent_option(IREE_TARGET_BACKEND_WEBGPU "Enables the 'webgpu' compiler target backend" OFF ${IREE_BUILD_COMPILER} OFF) diff --git a/build_tools/cmake/iree_check_test.cmake b/build_tools/cmake/iree_check_test.cmake index b7d6a24b58fca..c27341a2e7677 100644 --- a/build_tools/cmake/iree_check_test.cmake +++ b/build_tools/cmake/iree_check_test.cmake @@ -191,10 +191,10 @@ function(iree_check_single_backend_test_suite) if(DEFINED _RULE_DRIVER) string(TOUPPER ${_RULE_DRIVER} _UPPERCASE_DRIVER) string(REPLACE "-" "_" _NORMALIZED_DRIVER ${_UPPERCASE_DRIVER}) - if((NOT DEFINED IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) AND (NOT ${_NORMALIZED_DRIVER} IN_LIST _NORMALIZED_EXTERNAL_DRIVERS)) + if((NOT DEFINED IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) AND (NOT ${_NORMALIZED_DRIVER} IN_LIST _NORMALIZED_EXTERNAL_DRIVERS) AND (NOT ${_RULE_DRIVER} STREQUAL "level_zero")) message(SEND_ERROR "Unknown driver '${_RULE_DRIVER}'. 
Check IREE_HAL_DRIVER_*/IREE_EXTERNAL_HAL_DRIVERS options.") endif() - if((NOT IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) AND (NOT IREE_EXTERNAL_${_NORMALIZED_DRIVER}_HAL_DRIVER_FOUND)) + if((NOT IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) AND (NOT IREE_EXTERNAL_${_NORMALIZED_DRIVER}_HAL_DRIVER_FOUND) AND (NOT ${_RULE_DRIVER} STREQUAL "level_zero")) return() endif() endif() diff --git a/build_tools/cmake/iree_llvm.cmake b/build_tools/cmake/iree_llvm.cmake index 10689403e9a50..88e2a59189057 100644 --- a/build_tools/cmake/iree_llvm.cmake +++ b/build_tools/cmake/iree_llvm.cmake @@ -177,6 +177,9 @@ macro(iree_llvm_set_bundled_cmake_options) if(IREE_TARGET_BACKEND_METAL_SPIRV) message(STATUS " - metal-spirv") endif() + if(IREE_TARGET_BACKEND_OPENCL_SPIRV) + message(STATUS " - opencl-spirv") + endif() if(IREE_TARGET_BACKEND_ROCM) message(STATUS " - rocm") list(APPEND LLVM_TARGETS_TO_BUILD AMDGPU) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/WorkGroupSwizzle.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/WorkGroupSwizzle.cpp index 02e58e13807ac..f32ed3ef65e9f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/WorkGroupSwizzle.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/WorkGroupSwizzle.cpp @@ -8,6 +8,7 @@ #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" namespace mlir { namespace iree_compiler { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp index 8064dce858850..765bd211fd3d4 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp @@ -167,9 +167,40 @@ InterfaceResourceMap createResourceVariables(mlir::ModuleOp module) { //===----------------------------------------------------------------------===// namespace { + +// Helper 
type and function to get kernel arguments. +using SetBinding = std::pair; +/// Convention with the HAL side to pass kernel arguments. +/// The bindings are ordered based on binding set and binding index then +/// compressed and mapped to dense set of arguments. +/// This function looks at the symbols and return the mapping between +/// InterfaceBindingOp and kernel argument index. +/// For instance if the kernel has (set, bindings) A(0, 1), B(1, 5), C(0, 6) it +/// will return the mapping [A, 0], [C, 1], [B, 2] +static llvm::SmallDenseMap +getKernelArgMapping(Operation *funcOp) { + llvm::SetVector usedBindingSet; + funcOp->walk([&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) { + usedBindingSet.insert( + SetBinding(subspanOp.getSet(), subspanOp.getBinding())); + }); + auto sparseBindings = usedBindingSet.takeVector(); + std::sort(sparseBindings.begin(), sparseBindings.end(), + [](SetBinding lhs, SetBinding rhs) { + if (lhs.first == rhs.first) + return lhs.second.ult(rhs.second); + return lhs.first.ult(rhs.first); + }); + llvm::SmallDenseMap mapBindingArgIndex; + for (auto binding : llvm::enumerate(sparseBindings)) { + mapBindingArgIndex[binding.value()] = binding.index(); + } + return mapBindingArgIndex; +} + /// A pattern to convert hal.interface.constant.load into a sequence of SPIR-V /// ops to load from a global variable representing the push constant storage. -struct HALInterfaceLoadConstantConverter final +struct HALInterfaceLoadConstantToAccessChainLoadConverter final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -199,6 +230,30 @@ struct HALInterfaceLoadConstantConverter final } }; +/// A pattern to convert hal.interface.constant.load into the pointer from the +/// argument. This pass is to convert the region to mimic OpenCL styled kernels. 
+struct HALInterfaceLoadConstantToArgPointerConverter final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(IREE::HAL::InterfaceConstantLoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Bail until nested under an SPVFuncOp. + auto spirvFuncOp = loadOp->getParentOfType(); + if (!spirvFuncOp) + return failure(); + assert(spirvFuncOp.getNumArguments() > 0); + + auto argMapping = getKernelArgMapping(spirvFuncOp); + auto spirvBufferArg = spirvFuncOp.getArgument( + argMapping.size() + loadOp.getIndex().getZExtValue()); + assert(spirvBufferArg.getType().isInteger(32)); + rewriter.replaceOp(loadOp, spirvBufferArg); + return success(); + } +}; + /// A pattern to convert hal.interface.workgroup.id/count into corresponding /// SPIR-V Builtin ops. template @@ -233,9 +288,9 @@ struct HALInterfaceWorkgroupIdAndCountConverter final /// A pattern to convert hal.interface.binding.subspan into a sequence of SPIR-V /// ops to get the address to a global variable representing the resource /// buffer. -struct HALInterfaceBindingSubspanConverter final +struct HALInterfaceBindingSubspanToGlobalVarAddressConverter final : public OpConversionPattern { - HALInterfaceBindingSubspanConverter( + HALInterfaceBindingSubspanToGlobalVarAddressConverter( TypeConverter &typeConverter, MLIRContext *context, const InterfaceResourceMap &interfaceToResourceVars, PatternBenefit benefit = 1) @@ -277,6 +332,182 @@ struct HALInterfaceBindingSubspanConverter final const InterfaceResourceMap &interfaceToResourceVars; }; +/// A pattern to convert hal.interface.binding.subspan into the pointer from the +/// argument. This pass is to convert the region to mimic OpenCL styled kernels. 
+struct HALInterfaceBindingSubspanToArgPointerConverter final + : public OpConversionPattern { + HALInterfaceBindingSubspanToArgPointerConverter(TypeConverter &typeConverter, + MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp subspanOp, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (subspanOp.use_empty()) { + rewriter.eraseOp(subspanOp); + return success(); + } + + Type resultType = subspanOp.getOperation()->getResult(0).getType(); + Type convertedType = this->getTypeConverter()->convertType(resultType); + if (!convertedType) { + return subspanOp.emitError() + << "failed to convert SPIR-V type: " << resultType; + } + + // Bail until nested under an SPV::FuncOp. + auto spirvFuncOp = + subspanOp.getOperation()->getParentOfType(); + auto argMapping = getKernelArgMapping(spirvFuncOp); + size_t argIndex = argMapping.lookup( + SetBinding(subspanOp.getSet(), subspanOp.getBinding())); + if (argIndex >= argMapping.size()) + return failure(); + if (argIndex >= spirvFuncOp.getNumArguments()) + return failure(); + auto argValue = spirvFuncOp.getArgument(argIndex); + + // Same set-binding pair can contain different data with different types. + // In this case, we need to apply bitcasting. + spirv::PointerType argPtrType = + argValue.getType().dyn_cast(); + if (!argPtrType) { + return subspanOp.emitError() + << "Got something other than spv.ptr to replace subspan in " + "capability::Kernel, but got: " + << argValue.getType() << " instead."; + } + auto memrefType = subspanOp.getType().cast(); + Type subspanElType = memrefType.getElementType(); + auto argElType = argPtrType.getPointeeType(); + Value dataPtr = argValue; + // Bitcast to the different data type if necessary. 
+ if (argElType != subspanElType) { + auto dataPtrType = spirv::PointerType::get( + subspanElType, spirv::StorageClass::CrossWorkgroup); + dataPtr = rewriter.create(subspanOp.getLoc(), + dataPtrType, dataPtr); + } + + // Handling 0-D memref's by typecasting to spirv::Array. + if (memrefType.getRank() == 0) { + dataPtr.setType( + spirv::PointerType::get(spirv::ArrayType::get(subspanElType, 1), + spirv::StorageClass::CrossWorkgroup)); + } + + // Convert a dynamic shaped storage buffer into an spirv::Array of known + // dimension (obtained using attribute) + auto attr = subspanOp.getDescriptorTypeAttr() + .dyn_cast_or_null(); + if (memrefType.hasStaticShape() && memrefType.getRank() == 1 && + attr.getValue() == IREE::HAL::DescriptorType::StorageBuffer) { + dataPtr.setType(spirv::PointerType::get( + spirv::ArrayType::get(subspanElType, memrefType.getDimSize(0)), + spirv::StorageClass::CrossWorkgroup)); + } + + // Add the byte offset. + if (adaptor.getByteOffset()) { + auto offsetOp = + dyn_cast(adaptor.getByteOffset().getDefiningOp()); + if (!offsetOp) { + return subspanOp.emitError() + << "Found offset, but offset defining Op is expected to be " + "spv.constant, but is not."; + } + auto offsetVal = offsetOp.getValue().dyn_cast().getInt(); + if (offsetVal) { + return subspanOp.emitError() + << "Found offset, offset expected as int, but found: " + << offsetOp.getValue() << " instead."; + } + // Check that there is non-zero offset and add the byte offset if + // necessary. 
+ if (offsetVal != 0) { + SmallVector emptyIndices; + dataPtr = rewriter.create( + subspanOp.getLoc(), dataPtr, adaptor.getByteOffset(), emptyIndices); + } + } + rewriter.replaceOp(subspanOp, dataPtr); + return success(); + } +}; + +struct FuncOpToSPVConverter final : public OpConversionPattern { + FuncOpToSPVConverter(TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FunctionType fnType = funcOp.getFunctionType(); + (void)fnType; + if (!funcOp.isPublic()) + return failure(); + + // illegal FuncOp must have 0 inputs. + assert(fnType.getNumInputs() == 0 && fnType.getNumResults() == 0); + + TypeConverter::SignatureConversion signatureConverter(/*numOrigInputs=*/0); + auto argMapping = getKernelArgMapping(funcOp); + // There may be dead symbols, we pick i32 pointer as default argument type. + SmallVector spirvInputTypes( + argMapping.size(), + spirv::PointerType::get(rewriter.getI32Type(), + spirv::StorageClass::CrossWorkgroup)); + funcOp.walk([&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) { + auto memrefType = subspanOp.getType().cast(); + Type elType = memrefType.getElementType(); + Type inputConvertedSpirvType = + spirv::PointerType::get(elType, spirv::StorageClass::CrossWorkgroup); + spirvInputTypes[argMapping[SetBinding(subspanOp.getSet(), + subspanOp.getBinding())]] = + inputConvertedSpirvType; + }); + // As a convention with HAL, push constants are appended as kernel arguments + // after all the binding inputs. 
+ uint64_t numConstants = 0; + funcOp.walk([&](IREE::HAL::InterfaceConstantLoadOp constantOp) { + numConstants = + std::max(constantOp.getIndex().getZExtValue() + 1, numConstants); + }); + spirvInputTypes.resize(argMapping.size() + numConstants, + rewriter.getI32Type()); + if (!spirvInputTypes.empty()) + signatureConverter.addInputs(spirvInputTypes); + + auto spirvFuncType = + FunctionType::get(rewriter.getContext(), spirvInputTypes, + /*resultTypes=*/{}); + auto spirvFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), spirvFuncType, + spirv::FunctionControl::None); + + // Copy over all attributes other than the function name and type. + for (const auto &namedAttr : funcOp->getAttrs()) { + if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && + namedAttr.getName() != SymbolTable::getSymbolAttrName()) + spirvFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); + } + + // Copy all of funcOp's operations into spirvFuncOp's body and perform + // region type conversion. + rewriter.inlineRegionBefore(funcOp.getBody(), spirvFuncOp.getBody(), + spirvFuncOp.end()); + if (failed(rewriter.convertRegionTypes( + &spirvFuncOp.getBody(), *typeConverter, &signatureConverter))) { + return failure(); + } + rewriter.eraseOp(funcOp); + return success(); + } +}; + /// Pattern to lower operations that become a no-ops at this level. 
template struct FoldAsNoOp final : public OpConversionPattern { @@ -340,8 +571,10 @@ class ConvertToSPIRVPass : public ConvertToSPIRVBase { registry.insert(); } - explicit ConvertToSPIRVPass(bool enableFastMath, unsigned indexBits) - : enableFastMath(enableFastMath), indexBits(indexBits) {} + explicit ConvertToSPIRVPass(bool enableFastMath, unsigned indexBits, + spirv::AddressingModel addressingModel) + : enableFastMath(enableFastMath), indexBits(indexBits), + addressingModel(addressingModel) {} LogicalResult initializeOptions(StringRef options) override { if (failed(Pass::initializeOptions(options))) @@ -360,6 +593,8 @@ class ConvertToSPIRVPass : public ConvertToSPIRVBase { bool enableFastMath; // Use 64 bits for index widths. unsigned indexBits; + // Addressing model to use. + spirv::AddressingModel addressingModel; }; } // namespace @@ -426,6 +661,11 @@ void ConvertToSPIRVPass::runOnOperation() { spirv::TargetEnvAttr targetAttr = getSPIRVTargetEnvAttr(moduleOp); moduleOp->setAttr(spirv::getTargetEnvAttrName(), targetAttr); + if (addressingModel == spirv::AddressingModel::Physical32) + indexBits = 32; + else if (addressingModel == spirv::AddressingModel::Physical64) + indexBits = 64; + if (indexBits != 32 && indexBits != 64) { moduleOp.emitOpError( "Only 32-bit or 64-bit indices are supported for SPIR-V"); @@ -453,6 +693,13 @@ void ConvertToSPIRVPass::runOnOperation() { RewritePatternSet patterns(&getContext()); ScfToSPIRVContext scfToSPIRVContext; + bool hasKernelCapabilty = false; + for (auto capability : targetAttr.getCapabilities()) { + if (capability == spirv::Capability::Kernel) { + hasKernelCapabilty = true; + } + } + // Pull in GPU patterns to convert processor ID ops and loop ops. 
populateGPUToSPIRVPatterns(typeConverter, patterns); populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter, @@ -462,6 +709,7 @@ void ConvertToSPIRVPass::runOnOperation() { populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns); // Pull in MemRef patterns to convert load/store ops. + populateMemRefToSPIRVPatterns(typeConverter, patterns); // Pull in standard/math patterns to convert arithmetic ops and others. @@ -488,19 +736,30 @@ void ConvertToSPIRVPass::runOnOperation() { // Add IREE HAL interface op conversions. patterns.insert< - HALInterfaceLoadConstantConverter, HALInterfaceWorkgroupIdAndCountConverter< IREE::HAL::InterfaceWorkgroupIDOp, spirv::BuiltIn::WorkgroupId>, HALInterfaceWorkgroupIdAndCountConverter< IREE::HAL::InterfaceWorkgroupCountOp, spirv::BuiltIn::NumWorkgroups>>( typeConverter, context); - // Performs a prelimiary step to analyze all hal.interface.binding.subspan ops - // and create spirv.GlobalVariables. - auto interfaceToResourceVars = createResourceVariables(moduleOp); - // For using use them in conversion. - patterns.insert(typeConverter, context, - interfaceToResourceVars); + // Interface-Resource Map needs to be initialized in main region to prevent + // segfault. + InterfaceResourceMap interfaceToResourceVars; + if (hasKernelCapabilty) { + patterns.insert( + typeConverter, context); + } else { + patterns.insert( + typeConverter, context); + // Performs a prelimiary step to analyze all hal.interface.binding.subspan + // ops and create spirv.GlobalVariables. + interfaceToResourceVars = createResourceVariables(moduleOp); + // For using use them in conversion. + patterns.insert( + typeConverter, context, interfaceToResourceVars); + } /// Fold certain operations as no-ops: /// - linalg.reshape becomes a no-op since all memrefs are linearized in @@ -533,10 +792,20 @@ void ConvertToSPIRVPass::runOnOperation() { } // Collect all SPIR-V ops into a spirv.module. 
+ spirv::MemoryModel memoryModel = spirv::MemoryModel::GLSL450; + if (hasKernelCapabilty) { + if (addressingModel != spirv::AddressingModel::Physical64 && + addressingModel != spirv::AddressingModel::Physical32) { + moduleOp.emitOpError( + "Only Physical32 or Physical64 addressing models are supported for " + "OpenCL SPIR-V"); + return signalPassFailure(); + } + memoryModel = spirv::MemoryModel::OpenCL; + } auto builder = OpBuilder::atBlockBegin(moduleOp.getBody()); auto spvModule = builder.create( - moduleOp.getLoc(), spirv::AddressingModel::Logical, - spirv::MemoryModel::GLSL450); + moduleOp.getLoc(), addressingModel, memoryModel); Block *body = spvModule.getBody(); Dialect *spvDialect = spvModule->getDialect(); for (Operation &op : llvm::make_early_inc_range(*moduleOp.getBody())) { @@ -553,8 +822,10 @@ void ConvertToSPIRVPass::runOnOperation() { //===----------------------------------------------------------------------===// std::unique_ptr> -createConvertToSPIRVPass(bool enableFastMath, unsigned indexBits) { - return std::make_unique(enableFastMath, indexBits); +createConvertToSPIRVPass(bool enableFastMath, unsigned indexBits, + spirv::AddressingModel addressingModel) { + return std::make_unique(enableFastMath, indexBits, + addressingModel); } } // namespace iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index 9c586c6e6eb72..1579eabf27dc3 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -1173,6 +1173,10 @@ static LogicalResult setWinogradOpConfig(spirv::ResourceLimitsAttr limits, /// Set the configuration for reductions that can be mapped to warp reductions. static LogicalResult setReductionConfig(const spirv::TargetEnv &targetEnv, linalg::GenericOp op) { + // TODO: Fix/support Warp Reduction for LevelZero. 
+ if (targetEnv.allows(spirv::Capability::Kernel)) + return failure(); + LLVM_DEBUG(llvm::dbgs() << "trying to deduce config as reduction...\n"); auto funcOp = op->getParentOfType(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp index 129e4298673e0..c8b21733a7dc1 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp @@ -219,7 +219,8 @@ static void addMemRefLoweringPasses(OpPassManager &pm) { } /// Adds passes to perform the final SPIR-V conversion. -static void addSPIRVLoweringPasses(OpPassManager &pm, bool enableFastMath) { +static void addSPIRVLoweringPasses(OpPassManager &pm, bool enableFastMath, + spirv::AddressingModel addressingModel) { pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); @@ -239,7 +240,8 @@ static void addSPIRVLoweringPasses(OpPassManager &pm, bool enableFastMath) { pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - pm.addPass(createConvertToSPIRVPass(enableFastMath, clSPIRVIndexingBits)); + pm.addPass(createConvertToSPIRVPass(enableFastMath, clSPIRVIndexingBits, + addressingModel)); auto getTargetEnv = [](spirv::ModuleOp moduleOp) { return getSPIRVTargetEnvAttr(moduleOp); @@ -636,7 +638,8 @@ void addSPIRVTransformDialectPassPipeline(OpPassManager &pm) { // Entry Point //===----------------------------------------------------------------------===// -void buildSPIRVCodegenPassPipeline(OpPassManager &pm, bool enableFastMath) { +void buildSPIRVCodegenPassPipeline(OpPassManager &pm, bool enableFastMath, + spirv::AddressingModel addressingModel) { addCommonTargetExecutablePreprocessingPasses(pm.nest()); auto &nestedModulePM = pm.nest(); nestedModulePM.addNestedPass( @@ -644,7 +647,7 @@ void buildSPIRVCodegenPassPipeline(OpPassManager &pm, bool enableFastMath) { pm.addPass(createSPIRVLowerExecutableTargetPass()); addMemRefLoweringPasses(pm.nest()); - 
addSPIRVLoweringPasses(pm.nest(), enableFastMath); + addSPIRVLoweringPasses(pm.nest(), enableFastMath, addressingModel); LLVM_DEBUG({ llvm::dbgs() << "Using SPIR-V pass pipeline:\n"; @@ -670,7 +673,9 @@ void registerCodegenSPIRVPasses() { "iree-codegen-linalg-to-spirv-pipeline", "Runs the progressive lowering pipeline from linalg to SPIR-V", [](OpPassManager &passManager) { - buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false); + buildSPIRVCodegenPassPipeline( + passManager, /*enableFastMath=*/false, + /*addressingModel=*/spirv::AddressingModel::Logical); }); } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h index faf2d9fe1a6e6..9a3cd006e7f08 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h @@ -13,6 +13,7 @@ #define IREE_COMPILER_CODEGEN_SPIRV_PASSES_H_ #include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -53,15 +54,17 @@ void addSPIRVWinogradVectorizePassPipeline(OpPassManager &pm); /// Populates passes needed to lower linalg/arith/math ops to SPIR-V ops via /// the structured ops path. The pass manager `pm` here operate on the module /// within the IREE::HAL::ExecutableOp. -void buildSPIRVCodegenPassPipeline(OpPassManager &pm, bool enableFastMath); +void buildSPIRVCodegenPassPipeline(OpPassManager &pm, bool enableFastMath, + spirv::AddressingModel addressingModel); /// Pass to perform the final conversion to SPIR-V dialect. /// /// This pass converts remaining interface ops into SPIR-V global variables, /// GPU processor ID ops into SPIR-V global variables, loop/standard ops into /// corresponding SPIR-V ops. 
-std::unique_ptr> -createConvertToSPIRVPass(bool enableFastMath = false, unsigned indexWidth = 32); +std::unique_ptr> createConvertToSPIRVPass( + bool enableFastMath = false, unsigned indexWidth = 32, + spirv::AddressingModel addressingModel = spirv::AddressingModel::Logical); /// Annotates the innermost Winograd loops with the spirv distribute attribute. std::unique_ptr> diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp index dfb91f7958061..389d77bf506be 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp @@ -134,7 +134,9 @@ class MetalSPIRVTargetBackend : public TargetBackend { if (variantOp.isExternal()) return; - buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false); + buildSPIRVCodegenPassPipeline( + passManager, /*enableFastMath=*/false, + /*addressingModel=*/spirv::AddressingModel::Logical); } LogicalResult serializeExecutable(const SerializationOptions &options, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/BUILD new file mode 100644 index 0000000000000..1929ec15221d2 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/BUILD @@ -0,0 +1,51 @@ +# Copyright 2019 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content", "iree_compiler_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_cmake_extra_content( + content = """ +if(NOT IREE_TARGET_BACKEND_VULKAN_SPIRV) + return() +endif() +""", +) + +iree_compiler_cc_library( + name = "VulkanSPIRV", + srcs = [ + "VulkanSPIRVTarget.cpp", + ], + hdrs = [ + "VulkanSPIRVTarget.h", + ], + deps = [ + "//compiler/src/iree/compiler/Codegen:PassHeaders", + "//compiler/src/iree/compiler/Codegen/Common", + "//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect", + "//compiler/src/iree/compiler/Codegen/SPIRV", + "//compiler/src/iree/compiler/Codegen/Utils", + "//compiler/src/iree/compiler/Dialect/HAL/Target", + "//compiler/src/iree/compiler/Dialect/Vulkan/IR", + "//compiler/src/iree/compiler/Dialect/Vulkan/Utils", + "//compiler/src/iree/compiler/Utils", + "//runtime/src/iree/schemas:spirv_executable_def_c_fbs", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SPIRVDialect", + "@llvm-project//mlir:SPIRVModuleCombiner", + "@llvm-project//mlir:SPIRVSerialization", + "@llvm-project//mlir:Support", + ], +) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/CMakeLists.txt new file mode 100644 index 0000000000000..a8d46fbe16351 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/CMakeLists.txt @@ -0,0 +1,43 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/BUILD # +# # +# Use iree_cmake_extra_content from 
iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +if(NOT IREE_TARGET_BACKEND_OPENCL_SPIRV) + return() +endif() + +iree_add_all_subdirs() + +iree_cc_library( + NAME + OpenCLSPIRV + HDRS + "OpenCLSPIRVTarget.h" + SRCS + "OpenCLSPIRVTarget.cpp" + DEPS + LLVMSupport + MLIRGPUDialect + MLIRIR + MLIRParser + MLIRSPIRVDialect + MLIRSPIRVModuleCombiner + MLIRSPIRVSerialization + MLIRSupport + iree::compiler::Codegen::Common + iree::compiler::Codegen::Dialect::IREECodegenDialect + iree::compiler::Codegen::SPIRV + iree::compiler::Codegen::Utils + iree::compiler::Dialect::HAL::Target + iree::compiler::Utils + iree::schemas::level_zero_executable_def_c_fbs + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/OpenCLSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/OpenCLSPIRVTarget.cpp new file mode 100644 index 0000000000000..3d755d833cefc --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/OpenCLSPIRVTarget.cpp @@ -0,0 +1,272 @@ +// Copyright 2019 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/OpenCLSPIRVTarget.h" + +#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h" +#include "iree/compiler/Codegen/SPIRV/Passes.h" +#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" +#include "iree/compiler/Utils/FlatbufferUtils.h" +#include "iree/schemas/level_zero_executable_def_builder.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Target/SPIRV/Serialization.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace HAL { + +OpenCLSPIRVTargetOptions getOpenCLSPIRVTargetOptionsFromFlags() { + static llvm::cl::opt clOpenCLTargetTriple( + "iree-opencl-target-triple", llvm::cl::desc("OpenCL target triple"), + llvm::cl::init("spir-unknown-unknown")); + + static llvm::cl::opt clOpenCLUsePhysical32( + "iree-opencl-physical32-addressing", + llvm::cl::desc("Use Physical32 addressing with OpenCL"), + llvm::cl::init(false)); + + OpenCLSPIRVTargetOptions targetOptions; + targetOptions.openCLTargetTriple = clOpenCLTargetTriple; + targetOptions.openCLUsePhysical32 = clOpenCLUsePhysical32; + + return targetOptions; +} + +// Returns the Vulkan target environment for conversion. 
+static spirv::TargetEnvAttr +getSPIRVTargetEnv(const std::string &openCLTargetTriple, MLIRContext *context) { + // if (!openCLTargetTriple.empty()) { + // return convertTargetEnv( + // Vulkan::getTargetEnvForTriple(context, openCLTargetTriple)); + // } + auto triple = spirv::VerCapExtAttr::get( + spirv::Version::V_1_4, + {spirv::Capability::Kernel, + spirv::Capability::Addresses, + spirv::Capability::SubgroupDispatch, + spirv::Capability::Float16Buffer, + spirv::Capability::Int16, + spirv::Capability::Int8, + spirv::Capability::Vector16, + spirv::Capability::GenericPointer, + spirv::Capability::Groups, + spirv::Capability::ImageBasic, + spirv::Capability::Float16, + spirv::Capability::Linkage, + spirv::Capability::Int64Atomics, + spirv::Capability::Int64, + spirv::Capability::Float64, + spirv::Capability::GroupNonUniform, + spirv::Capability::GroupNonUniformVote, + spirv::Capability::GroupNonUniformBallot, + spirv::Capability::GroupNonUniformArithmetic, + spirv::Capability::GroupNonUniformShuffle, + spirv::Capability::GroupNonUniformShuffleRelative, + spirv::Capability::GroupNonUniformClustered, + spirv::Capability::AtomicFloat16AddEXT, + spirv::Capability::AtomicFloat32AddEXT, + spirv::Capability::AtomicFloat64AddEXT, + spirv::Capability::LiteralSampler, + spirv::Capability::Sampled1D, + spirv::Capability::Image1D, + spirv::Capability::SampledBuffer, + spirv::Capability::ImageBuffer, + spirv::Capability::ImageReadWrite}, + {spirv::Extension::SPV_INTEL_subgroups, + spirv::Extension::SPV_EXT_shader_atomic_float_add, + spirv::Extension::SPV_EXT_shader_atomic_float16_add, + spirv::Extension::SPV_EXT_shader_atomic_float_min_max, + spirv::Extension::SPV_KHR_linkonce_odr}, + context); + return spirv::TargetEnvAttr::get( + triple, spirv::getDefaultResourceLimits(context), + spirv::ClientAPI::Unknown, spirv::Vendor::Unknown, + spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID); + return {}; +} + +class OpenCLSPIRVTargetBackend : public TargetBackend { 
+public: + OpenCLSPIRVTargetBackend(OpenCLSPIRVTargetOptions options) + : options_(std::move(options)) {} + + std::string name() const override { return "opencl"; } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + IREE::HAL::DeviceTargetAttr + getDefaultDeviceTarget(MLIRContext *context) const override { + Builder b(context); + SmallVector configItems; + + // Indicates that the runtime HAL driver operates only in the legacy + // synchronous mode. + configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr()); + + configItems.emplace_back(b.getStringAttr("executable_targets"), + getExecutableTargets(context)); + + auto configAttr = b.getDictionaryAttr(configItems); + return IREE::HAL::DeviceTargetAttr::get( + context, b.getStringAttr(deviceID()), configAttr); + } + + void buildTranslationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp, + OpPassManager &passManager) override { + // For now we disable translation if the variant has external object files. + // We could instead perform linking with those objects (if they're .spv + // files we could use spirv-link or import them into MLIR and merge here). 
+ if (variantOp.isExternal()) + return; + + spirv::AddressingModel addressingModel = spirv::AddressingModel::Physical64; + if (options_.openCLUsePhysical32) + addressingModel = spirv::AddressingModel::Physical32; + + buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false, + addressingModel); + } + + LogicalResult serializeExecutable(const SerializationOptions &options, + IREE::HAL::ExecutableVariantOp variantOp, + OpBuilder &executableBuilder) override { + ModuleOp innerModuleOp = variantOp.getInnerModule(); + auto spirvModuleOps = innerModuleOp.getOps(); + if (!llvm::hasSingleElement(spirvModuleOps)) { + return variantOp.emitError() + << "should only contain exactly one spv.module op"; + } + auto spvModuleOp = *spirvModuleOps.begin(); + + FlatbufferBuilder builder; + iree_LEVEL_ZEROExecutableDef_start_as_root(builder); + + // Serialize the spirv::ModuleOp into the binary that we will embed in the + // final FlatBuffer. + SmallVector spvBinary; + if (failed(spirv::serialize(spvModuleOp, spvBinary)) || spvBinary.empty()) { + return variantOp.emitError() << "failed to serialize spv.module"; + } + + // if (!options.dumpBinariesPath.empty()) { + // dumpDataToPath(options.dumpBinariesPath, + // options.dumpBaseName, + // variantOp.getName(), ".spv", spvBinary); + // } + + auto spvCodeRef = flatbuffers_uint32_vec_create(builder, spvBinary.data(), + spvBinary.size()); + + // The sequencer and runtime use ordinals instead of names. We provide the + // list of entry point names here that are then passed in + // zeModuleCreate. 
+ SmallVector entryPointNames; + std::vector> workgroupSizes; + spvModuleOp.walk([&](spirv::ExecutionModeOp executionModelOp) { + entryPointNames.push_back(executionModelOp.getFn()); + ArrayAttr workGroupSizeAttr = executionModelOp.getValues(); + assert(workGroupSizeAttr.size() == 3 && + "workgroup size is expected to be 3"); + workgroupSizes.push_back( + {int(workGroupSizeAttr[0].dyn_cast().getInt()), + int(workGroupSizeAttr[1].dyn_cast().getInt()), + int(workGroupSizeAttr[2].dyn_cast().getInt())}); + }); + // if (!options.dumpBinariesPath.empty()) { + // dumpDataToPath("/tmp", entryPointNames[0], + // variantOp.getName(), ".spv", spvBinary); + // } + + auto entryPointsRef = builder.createStringVec(entryPointNames); + iree_LEVEL_ZEROBlockSizeDef_vec_start(builder); + auto blockSizes = workgroupSizes.begin(); + for (int i = 0, e = entryPointNames.size(); i < e; ++i) { + iree_LEVEL_ZEROBlockSizeDef_vec_push_create( + builder, (*blockSizes)[0], (*blockSizes)[1], (*blockSizes)[2]); + ++blockSizes; + } + auto blockSizesRef = iree_LEVEL_ZEROBlockSizeDef_vec_end(builder); + + iree_LEVEL_ZEROExecutableDef_entry_points_add(builder, entryPointsRef); + iree_LEVEL_ZEROExecutableDef_block_sizes_add(builder, blockSizesRef); + iree_LEVEL_ZEROExecutableDef_level_zero_image_add(builder, spvCodeRef); + iree_LEVEL_ZEROExecutableDef_end_as_root(builder); + + // Add the binary data to the target executable. + auto binaryOp = executableBuilder.create( + variantOp.getLoc(), variantOp.getSymName(), + variantOp.getTarget().getFormat(), + builder.getBufferAttr(executableBuilder.getContext())); + binaryOp.setMimeTypeAttr( + executableBuilder.getStringAttr("application/x-flatbuffers")); + + return success(); + } + +private: + ArrayAttr getExecutableTargets(MLIRContext *context) const { + SmallVector targetAttrs; + // If we had multiple target environments we would generate one target attr + // per environment, with each setting its own environment attribute. 
+ targetAttrs.push_back(getExecutableTarget( + context, getSPIRVTargetEnv(options_.openCLTargetTriple, context))); + return ArrayAttr::get(context, targetAttrs); + } + + IREE::HAL::ExecutableTargetAttr + getExecutableTarget(MLIRContext *context, + spirv::TargetEnvAttr targetEnv) const { + Builder b(context); + SmallVector configItems; + + configItems.emplace_back(b.getStringAttr(spirv::getTargetEnvAttrName()), + targetEnv); + + auto configAttr = b.getDictionaryAttr(configItems); + return IREE::HAL::ExecutableTargetAttr::get( + context, b.getStringAttr("opencl"), b.getStringAttr("opencl-spirv-fb"), + configAttr); + } + + OpenCLSPIRVTargetOptions options_; +}; + +void registerOpenCLSPIRVTargetBackends( + std::function queryOptions) { + getOpenCLSPIRVTargetOptionsFromFlags(); + auto backendFactory = [=]() { + return std::make_shared(queryOptions()); + }; + // #hal.device.target<"opencl", ... + static TargetBackendRegistration registration0("opencl", backendFactory); + // #hal.executable.target<"opencl-spirv", ... + static TargetBackendRegistration registration1("opencl-spirv", + backendFactory); +} + +} // namespace HAL +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/OpenCLSPIRVTarget.h b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/OpenCLSPIRVTarget.h new file mode 100644 index 0000000000000..ea0ae18f50e92 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/OpenCLSPIRVTarget.h @@ -0,0 +1,38 @@ +// Copyright 2019 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
// Options controlling the SPIR-V translation.
struct OpenCLSPIRVTargetOptions {
  // OpenCL target triple.
  std::string openCLTargetTriple;
  // When true the backend selects the Physical32 SPIR-V addressing model
  // instead of Physical64. Initialized so that a default-constructed options
  // struct has a well-defined value.
  bool openCLUsePhysical32 = false;
};
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_lit_test_suite( + name = "lit", + srcs = enforce_glob( + [ + "linking.mlir", + "smoketest.mlir", + ], + include = ["*.mlir"], + ), + cfg = "//compiler:lit.cfg.py", + tools = [ + "//tools:iree-opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/CMakeLists.txt new file mode 100644 index 0000000000000..3ef7672b498d4 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/CMakeLists.txt @@ -0,0 +1,24 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. 
# +################################################################################ + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "linking.mlir" + "smoketest.mlir" + TOOLS + FileCheck + iree-opt +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/linking.mlir b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/linking.mlir new file mode 100644 index 0000000000000..1c31fc993e8ae --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/linking.mlir @@ -0,0 +1,142 @@ +// TODO(antiagainst): Re-enable SPIR-V linking once the tensorflow integration +// crash is fixed. +// RUN-disabled: iree-opt --split-input-file --iree-hal-link-target-executables='target=vulkan-spirv' %s | FileCheck %s +// RUN: iree-opt --split-input-file %s + +#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan", "vulkan-spirv-fb"> + +#executable_layout_0 = #hal.executable.layout, + #hal.descriptor_set.binding<1, storage_buffer> + ]> +]> + +#executable_layout_1 = #hal.executable.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> + +hal.executable private @call_dispatch_0 { + hal.executable.variant @vulkan_spirv_fb, target = #executable_target_vulkan_spirv_fb { + hal.executable.export @call_dispatch_0 ordinal(0) layout(#executable_layout_0) + builtin.module { + spv.module Logical GLSL450 requires #spv.vce { + spv.func @call_dispatch_0() "None" { + spv.Return + } + spv.EntryPoint "GLCompute" @call_dispatch_0 + spv.ExecutionMode @call_dispatch_0 "LocalSize", 32, 1, 1 + } + } + } +} +hal.executable private @call_dispatch_1 { + hal.executable.variant @vulkan_spirv_fb, target = #executable_target_vulkan_spirv_fb { + hal.executable.export @call_dispatch_1 ordinal(0) layout(#executable_layout_1) + builtin.module { + spv.module Logical GLSL450 requires #spv.vce { + spv.func 
@call_dispatch_1() "None" { + spv.Return + } + spv.EntryPoint "GLCompute" @call_dispatch_1 + spv.ExecutionMode @call_dispatch_1 "LocalSize", 4, 4, 1 + } + } + } +} +hal.executable private @call_dispatch_2 { + hal.executable.variant @vulkan_spirv_fb, target = #executable_target_vulkan_spirv_fb { + hal.executable.export @call_dispatch_2 ordinal(0) layout(#executable_layout_0) + builtin.module { + spv.module Logical GLSL450 requires #spv.vce { + spv.func @call_dispatch_2() "None" { + spv.Return + } + spv.EntryPoint "GLCompute" @call_dispatch_2 + spv.ExecutionMode @call_dispatch_2 "LocalSize", 32, 1, 1 + } + } + } +} +hal.executable private @call_dispatch_3 { + hal.executable.variant @vulkan_spirv_fb, target = #executable_target_vulkan_spirv_fb { + hal.executable.export @call_dispatch_3 ordinal(0) layout(#executable_layout_1) + builtin.module { + spv.module Logical GLSL450 requires #spv.vce { + spv.func @call_dispatch_3() "None" { + spv.Return + } + spv.EntryPoint "GLCompute" @call_dispatch_3 + spv.ExecutionMode @call_dispatch_3 "LocalSize", 8, 2, 2 + } + } + } +} +hal.executable private @call_dispatch_4 { + hal.executable.variant @vulkan_spirv_fb, target = #executable_target_vulkan_spirv_fb { + hal.executable.export @call_dispatch_4 ordinal(0) layout(#executable_layout_1) + builtin.module { + spv.module Logical GLSL450 requires #spv.vce { + spv.func @call_dispatch_4() "None" { + spv.Return + } + spv.EntryPoint "GLCompute" @call_dispatch_4 + spv.ExecutionMode @call_dispatch_4 "LocalSize", 2, 8, 1 + } + } + } +} + +// Two groups should be created, according to their interfaces. 
+ +// CHECK: hal.executable private @linking_linked_vulkan_0 { +// CHECK-NEXT: hal.executable.variant public @vulkan_spirv_fb, target = #executable_target_vulkan_spirv_fb { +// CHECK-NEXT: hal.executable.export public @call_dispatch_1 ordinal(0) layout(#executable_layout_0) +// CHECK-NEXT: hal.executable.export public @call_dispatch_3 ordinal(1) layout(#executable_layout_0) +// CHECK-NEXT: hal.executable.export public @call_dispatch_4 ordinal(2) layout(#executable_layout_0) +// CHECK-NEXT: module { +// CHECK-NEXT: spv.module Logical GLSL450 requires #spv.vce { +// CHECK-NEXT: spv.func @call_dispatch_1() "None" { +// CHECK-NEXT: spv.Return +// CHECK-NEXT: } +// CHECK-NEXT: spv.EntryPoint "GLCompute" @call_dispatch_1 +// CHECK-NEXT: spv.ExecutionMode @call_dispatch_1 "LocalSize", 4, 4, 1 +// CHECK-NEXT: spv.func @call_dispatch_3() "None" { +// CHECK-NEXT: spv.Return +// CHECK-NEXT: } +// CHECK-NEXT: spv.EntryPoint "GLCompute" @call_dispatch_3 +// CHECK-NEXT: spv.ExecutionMode @call_dispatch_3 "LocalSize", 8, 2, 2 +// CHECK-NEXT: spv.func @call_dispatch_4() "None" { +// CHECK-NEXT: spv.Return +// CHECK-NEXT: } +// CHECK-NEXT: spv.EntryPoint "GLCompute" @call_dispatch_4 +// CHECK-NEXT: spv.ExecutionMode @call_dispatch_4 "LocalSize", 2, 8, 1 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: hal.executable private @linking_linked_vulkan { +// CHECK-NEXT: hal.executable.variant public @vulkan_spirv_fb, target = #executable_target_vulkan_spirv_fb { +// CHECK-NEXT: hal.executable.export public @call_dispatch_0 ordinal(0) layout(#executable_layout_1) +// CHECK-NEXT: hal.executable.export public @call_dispatch_2 ordinal(1) layout(#executable_layout_1) +// CHECK-NEXT: module { +// CHECK-NEXT: spv.module Logical GLSL450 requires #spv.vce { +// CHECK-NEXT: spv.func @call_dispatch_0() "None" { +// CHECK-NEXT: spv.Return +// CHECK-NEXT: } +// CHECK-NEXT: spv.EntryPoint "GLCompute" @call_dispatch_0 +// CHECK-NEXT: spv.ExecutionMode 
@call_dispatch_0 "LocalSize", 32, 1, 1 +// CHECK-NEXT: spv.func @call_dispatch_2() "None" { +// CHECK-NEXT: spv.Return +// CHECK-NEXT: } +// CHECK-NEXT: spv.EntryPoint "GLCompute" @call_dispatch_2 +// CHECK-NEXT: spv.ExecutionMode @call_dispatch_2 "LocalSize", 32, 1, 1 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/smoketest.mlir b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/smoketest.mlir new file mode 100644 index 0000000000000..585b037919484 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/test/smoketest.mlir @@ -0,0 +1,39 @@ +// RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline %s | FileCheck %s + +module attributes { + hal.device.targets = [ + #hal.device.target<"vulkan", { + executable_targets = [ + #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", { + spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits<>> + }> + ] + }> + ] +} { + +stream.executable public @reduce_dispatch { + stream.executable.export @reduce_dispatch + builtin.module { + func.func @reduce_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding) { + %c0 = arith.constant 0 : index + %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor + %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor + %0 = linalg.init_tensor [] : tensor + %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor -> tensor<16xf32> + %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%1 : tensor<16xf32>) outs(%0 : tensor) { + ^bb0(%arg2: f32, %arg3: f32): + %4 = arith.addf %arg2, %arg3 : f32 + linalg.yield %4 : f32 + } -> tensor + flow.dispatch.tensor.store %3, %arg1, offsets=[], sizes=[], strides=[] : tensor -> 
!flow.dispatch.tensor + return + } + } +} + +} + +// CHECK: hal.executable.binary public @vulkan_spirv_fb attributes +// CHECK-SAME: data = dense +// CHECK-SAME: format = "vulkan-spirv-fb" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp index 829dbefa7b871..77ea5be304991 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp @@ -131,7 +131,9 @@ class VulkanSPIRVTargetBackend : public TargetBackend { if (variantOp.isExternal()) return; - buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false); + buildSPIRVCodegenPassPipeline( + passManager, /*enableFastMath=*/false, + /*addressingModel=*/spirv::AddressingModel::Logical); } LogicalResult serializeExecutable(const SerializationOptions &options, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp index 0905020df794e..5a2d61f3fd55f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp @@ -114,7 +114,9 @@ class WebGPUTargetBackend : public TargetBackend { // ways to check whether a floating point number is NaN or infinity. // Therefore, just let the SPIR-V CodeGen to avoid generating guards w.r.t. // NaN and infinity. - buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/true); + buildSPIRVCodegenPassPipeline( + passManager, /*enableFastMath=*/true, + /*addressingModel=*/spirv::AddressingModel::Logical); // WGSL does not support extended multiplication: // https://github.com/gpuweb/gpuweb/issues/1565. 
Make sure to lower it to diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt index 69aac549af588..2802201a48b2e 100644 --- a/compiler/src/iree/compiler/Tools/CMakeLists.txt +++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt @@ -17,6 +17,10 @@ if(IREE_TARGET_BACKEND_METAL_SPIRV) list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::MetalSPIRV) list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_METALSPIRV_TARGET") endif() +if(IREE_TARGET_BACKEND_OPENCL_SPIRV) + list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::OpenCLSPIRV) + list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_OPENCLSPIRV_TARGET") +endif() if(IREE_TARGET_BACKEND_VMVX) list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::VMVX) list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_VMVX_TARGET") diff --git a/compiler/src/iree/compiler/Tools/init_targets.cc b/compiler/src/iree/compiler/Tools/init_targets.cc index 6da09899846a6..f7a6df5c6c34f 100644 --- a/compiler/src/iree/compiler/Tools/init_targets.cc +++ b/compiler/src/iree/compiler/Tools/init_targets.cc @@ -14,6 +14,9 @@ #ifdef IREE_HAVE_METALSPIRV_TARGET #include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.h" #endif // IREE_HAVE_METALSPIRV_TARGET +#ifdef IREE_HAVE_OPENCLSPIRV_TARGET +#include "iree/compiler/Dialect/HAL/Target/OpenCLSPIRV/OpenCLSPIRVTarget.h" +#endif // IREE_HAVE_OPENCLSPIRV_TARGET #ifdef IREE_HAVE_ROCM_TARGET #include "iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.h" #endif // IREE_HAVE_ROCM_TARGET @@ -44,6 +47,10 @@ void registerHALTargetBackends() { #ifdef IREE_HAVE_METALSPIRV_TARGET IREE::HAL::registerMetalSPIRVTargetBackends(); #endif // IREE_HAVE_METALSPIRV_TARGET +#ifdef IREE_HAVE_OPENCLSPIRV_TARGET + IREE::HAL::registerOpenCLSPIRVTargetBackends( + []() { return IREE::HAL::getOpenCLSPIRVTargetOptionsFromFlags(); }); +#endif // IREE_HAVE_OPENCLSPIRV_TARGET #ifdef IREE_HAVE_ROCM_TARGET 
IREE::HAL::registerROCMTargetBackends(); #endif // IREE_HAVE_ROCM_TARGET diff --git a/experimental/level_zero/CMakeLists.txt b/experimental/level_zero/CMakeLists.txt new file mode 100644 index 0000000000000..139c53c918b30 --- /dev/null +++ b/experimental/level_zero/CMakeLists.txt @@ -0,0 +1,105 @@ +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +set(IREE_PACKAGE_ROOT_DIR ${CMAKE_CURRENT_LIST_DIR}/../..) +# Canonicalize path. +cmake_path(ABSOLUTE_PATH IREE_PACKAGE_ROOT_DIR + BASE_DIRECTORY ${IREE_PACKAGE_ROOT_DIR} + NORMALIZE + OUTPUT_VARIABLE IREE_PACKAGE_ROOT_DIR) +set(IREE_PACKAGE_ROOT_PREFIX iree) + +iree_add_all_subdirs() + +if(NOT LEVEL_ZERO_HEADERS_API_ROOT) + set(LEVEL_ZERO_HEADERS_API_ROOT "${IREE_ROOT_DIR}/third_party/level-zero/") + message(STATUS "Using default level-zero directory at ${LEVEL_ZERO_HEADERS_API_ROOT}") +endif() + +if(EXISTS ${LEVEL_ZERO_HEADERS_API_ROOT}) + message(STATUS "Level Zero Header Path: ${LEVEL_ZERO_HEADERS_API_ROOT}") +else() + message(SEND_ERROR "Could not locate Level Zero: ${LEVEL_ZERO_HEADERS_API_ROOT}") +endif() + +iree_cc_library( + NAME + level_zero + HDRS + "api.h" + SRCS + "api.h" + "context_wrapper.h" + "level_zero_allocator.c" + "level_zero_allocator.h" + "level_zero_buffer.c" + "level_zero_buffer.h" + "level_zero_device.c" + "level_zero_device.h" + "level_zero_driver.c" + "level_zero_event.c" + "level_zero_event.h" + "event_semaphore.c" + "event_semaphore.h" + "pipeline_layout.c" + "pipeline_layout.h" + "direct_command_buffer.c" + "direct_command_buffer.h" + "native_executable.c" + "native_executable.h" + "nop_executable_cache.c" + "nop_executable_cache.h" + "status_util.c" + "status_util.h" + INCLUDES + "${CMAKE_CURRENT_LIST_DIR}/../.." 
+ "${PROJECT_BINARY_DIR}" + "${LEVEL_ZERO_HEADERS_API_ROOT}" + DEPS + ::dynamic_symbols + iree::base + iree::base::internal + iree::base::internal::arena + iree::base::internal::flatcc::parsing + iree::base::internal::synchronization + iree::hal + iree::hal::utils::buffer_transfer + iree::hal::utils::semaphore_base + iree::schemas::level_zero_executable_def_c_fbs + PUBLIC +) + +iree_cc_library( + NAME + dynamic_symbols + HDRS + "dynamic_symbols.h" + TEXTUAL_HDRS + "dynamic_symbol_tables.h" + SRCS + "level_zero_headers.h" + "dynamic_symbols.c" + INCLUDES + "${LEVEL_ZERO_HEADERS_API_ROOT}/include" + "${CMAKE_CURRENT_LIST_DIR}/../.." + DEPS + iree::base::internal::dynamic_library + PUBLIC +) + +iree_cc_test( + NAME + dynamic_symbols_test + SRCS + "dynamic_symbols_test.cc" + DEPS + ::dynamic_symbols + iree::base + iree::testing::gtest + iree::testing::gtest_main + LABELS + "driver=level_zero" +) diff --git a/experimental/level_zero/api.h b/experimental/level_zero/api.h new file mode 100644 index 0000000000000..61da353adb0dc --- /dev/null +++ b/experimental/level_zero/api.h @@ -0,0 +1,47 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// See iree/base/api.h for documentation on the API conventions used. + +#ifndef IREE_HAL_LEVEL_ZERO_API_H_ +#define IREE_HAL_LEVEL_ZERO_API_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// iree_hal_level_zero_driver_t +//===----------------------------------------------------------------------===// + +// LEVEL_ZERO driver creation options. +typedef struct iree_hal_level_zero_driver_options_t { + // Index of the default LEVEL_ZERO device to use within the list of available + // devices. 
+ int default_device_index; +} iree_hal_level_zero_driver_options_t; + +IREE_API_EXPORT void iree_hal_level_zero_driver_options_initialize( + iree_hal_level_zero_driver_options_t *out_options); + +// Creates a LEVEL_ZERO HAL driver that manage its own level zero context. +// +// |out_driver| must be released by the caller (see |iree_hal_driver_release|). +IREE_API_EXPORT iree_status_t iree_hal_level_zero_driver_create( + iree_string_view_t identifier, + const iree_hal_level_zero_driver_options_t *options, + iree_allocator_t host_allocator, iree_hal_driver_t **out_driver); + +// TODO(thomasraoux): Support importing a CUcontext from app. + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LEVEL_ZERO_API_H_ diff --git a/experimental/level_zero/context_wrapper.h b/experimental/level_zero/context_wrapper.h new file mode 100644 index 0000000000000..3a1984b3853a2 --- /dev/null +++ b/experimental/level_zero/context_wrapper.h @@ -0,0 +1,22 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_CONTEXT_WRAPPER_H_ +#define IREE_HAL_LEVEL_ZERO_CONTEXT_WRAPPER_H_ + +#include "experimental/level_zero/dynamic_symbols.h" +#include "experimental/level_zero/level_zero_headers.h" +#include "iree/hal/api.h" + +// Structure to wrap all objects constant within a context. This makes it +// simpler to pass it to the different objects and saves memory. 
// Bundle of per-context state handed by pointer to the HAL objects created
// within one Level Zero context.
typedef struct iree_hal_level_zero_context_wrapper_t {
  // Level Zero context handle (passed to ze* calls that take a context).
  ze_context_handle_t level_zero_context;
  // Host allocator used for the driver's own host-side allocations.
  iree_allocator_t host_allocator;
  // Level Zero symbol table through which the LEVEL_ZERO_* status macros
  // issue ze* calls. Presumably loaded dynamically, per the type name.
  iree_hal_level_zero_dynamic_symbols_t *syms;
} iree_hal_level_zero_context_wrapper_t;
// State of a direct Level Zero command buffer.
typedef struct {
  // HAL base object; must stay the first member because the base pointer is
  // cast directly to this struct.
  iree_hal_command_buffer_t base;
  // Shared per-context state (Level Zero context, host allocator, symbols).
  iree_hal_level_zero_context_wrapper_t* context;
  // Arena block pool supplied at creation time.
  iree_arena_block_pool_t* block_pool;
  // Command list that recorded commands are appended to.
  ze_command_list_handle_t command_list;

  // Keep track of the current set of kernel arguments.
  // Push constants are stored as 32-bit words.
  int32_t push_constant[IREE_HAL_LEVEL_ZERO_MAX_PUSH_CONSTANT_COUNT];
  // Trailing storage sized at allocation time: a table of kernel-argument
  // pointer slots followed by the device pointer values those slots point at.
  void* current_descriptor[];
} iree_hal_level_zero_direct_command_buffer_t;
+iree_status_t iree_hal_level_zero_direct_command_buffer_create( + iree_hal_device_t* device, iree_hal_level_zero_context_wrapper_t* context, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, + iree_arena_block_pool_t* block_pool, ze_device_handle_t level_zero_device, + uint32_t command_queue_ordinal, + iree_hal_command_buffer_t** out_command_buffer) { + *out_command_buffer = NULL; + + if (binding_capacity > 0) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect command buffers not yet implemented"); + } + + IREE_ASSERT_ARGUMENT(context); + IREE_ASSERT_ARGUMENT(block_pool); + IREE_ASSERT_ARGUMENT(out_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_level_zero_direct_command_buffer_t* command_buffer = NULL; + size_t total_size = sizeof(*command_buffer) + + IREE_HAL_LEVEL_ZERO_MAX_KERNEL_ARG * sizeof(void*) + + IREE_HAL_LEVEL_ZERO_MAX_KERNEL_ARG * + sizeof(iree_hal_level_zero_device_ptr_t); + iree_status_t status = iree_allocator_malloc( + context->host_allocator, total_size, (void**)&command_buffer); + if (iree_status_is_ok(status)) { + iree_hal_command_buffer_initialize( + device, mode, command_categories, queue_affinity, binding_capacity, + &iree_hal_level_zero_direct_command_buffer_vtable, + &command_buffer->base); + command_buffer->context = context; + command_buffer->block_pool = block_pool; + iree_hal_level_zero_device_ptr_t* device_ptrs = + (iree_hal_level_zero_device_ptr_t*)(command_buffer->current_descriptor + + IREE_HAL_LEVEL_ZERO_MAX_KERNEL_ARG); + for (size_t i = 0; i < IREE_HAL_LEVEL_ZERO_MAX_KERNEL_ARG; i++) { + command_buffer->current_descriptor[i] = &device_ptrs[i]; + } + // Create a command list + ze_command_list_handle_t command_list; + ze_command_list_desc_t command_list_desc = { + .stype = ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC}; + command_list_desc.commandQueueGroupOrdinal = command_queue_ordinal; + 
LEVEL_ZERO_RETURN_IF_ERROR( + command_buffer->context->syms, + zeCommandListCreateImmediate( + command_buffer->context->level_zero_context, level_zero_device, + &command_list_desc, &command_list), + "zeCommandListCreateImmediate"); + command_buffer->command_list = command_list; + + *out_command_buffer = &command_buffer->base; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_level_zero_direct_command_buffer_destroy( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_level_zero_direct_command_buffer_t* command_buffer = + iree_hal_level_zero_direct_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(command_buffer->context->host_allocator, command_buffer); + + IREE_TRACE_ZONE_END(z0); +} + +bool iree_hal_level_zero_direct_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer) { + return iree_hal_resource_is( + &command_buffer->resource, + &iree_hal_level_zero_direct_command_buffer_vtable); +} + +static void* iree_hal_level_zero_direct_command_buffer_dyn_cast( + iree_hal_command_buffer_t* command_buffer, const void* vtable) { + if (vtable == &iree_hal_level_zero_direct_command_buffer_vtable) { + IREE_HAL_ASSERT_TYPE(command_buffer, vtable); + return command_buffer; + } + return NULL; +} + +static iree_status_t iree_hal_level_zero_direct_command_buffer_begin( + iree_hal_command_buffer_t* base_command_buffer) { + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_direct_command_buffer_end( + iree_hal_command_buffer_t* base_command_buffer) { + return iree_ok_status(); +} + +static void iree_hal_level_zero_direct_command_buffer_begin_debug_group( + iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label, + iree_hal_label_color_t label_color, + const iree_hal_label_location_t* location) { + // TODO(benvanik): tracy event stack. 
+} + +static void iree_hal_level_zero_direct_command_buffer_end_debug_group( + iree_hal_command_buffer_t* base_command_buffer) { + // TODO(benvanik): tracy event stack. +} + +static iree_status_t +iree_hal_level_zero_direct_command_buffer_execution_barrier( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_hal_execution_barrier_flags_t flags, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + iree_hal_level_zero_direct_command_buffer_t* command_buffer = + iree_hal_level_zero_direct_command_buffer_cast(base_command_buffer); + LEVEL_ZERO_RETURN_IF_ERROR( + command_buffer->context->syms, + zeCommandListAppendBarrier(command_buffer->command_list, NULL, 0, NULL), + "zeCommandListAppendMemoryFill"); + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_direct_command_buffer_signal_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + // TODO: Implement barrier + iree_hal_level_zero_direct_command_buffer_t* command_buffer = + iree_hal_level_zero_direct_command_buffer_cast(base_command_buffer); + LEVEL_ZERO_RETURN_IF_ERROR( + command_buffer->context->syms, + zeCommandListAppendSignalEvent(command_buffer->command_list, + iree_hal_level_zero_event_handle(event)), + "zeCommandListAppendSignalEvent"); + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_direct_command_buffer_reset_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + iree_hal_level_zero_direct_command_buffer_t* command_buffer = + iree_hal_level_zero_direct_command_buffer_cast(base_command_buffer); + LEVEL_ZERO_RETURN_IF_ERROR( + command_buffer->context->syms, + 
zeCommandListAppendEventReset(command_buffer->command_list, + iree_hal_level_zero_event_handle(event)), + "zeCommandListAppendEventReset"); + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_direct_command_buffer_wait_events( + iree_hal_command_buffer_t* base_command_buffer, + iree_host_size_t event_count, const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + iree_hal_level_zero_direct_command_buffer_t* command_buffer = + iree_hal_level_zero_direct_command_buffer_cast(base_command_buffer); + iree_inline_array(ze_event_handle_t, event_handles, event_count, + command_buffer->context->host_allocator); + for (int i = 0; i < event_count; ++i) { + *iree_inline_array_at(event_handles, i) = + iree_hal_level_zero_event_handle(events[i]); + } + LEVEL_ZERO_RETURN_IF_ERROR( + command_buffer->context->syms, + zeCommandListAppendWaitOnEvents(command_buffer->command_list, event_count, + iree_inline_array_data(event_handles)), + "zeCommandListAppendWaitOnEvents"); + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_direct_command_buffer_discard_buffer( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) { + // nothing to do. 
+ return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_direct_command_buffer_fill_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, const void* pattern, + iree_host_size_t pattern_length) { + iree_hal_level_zero_direct_command_buffer_t* command_buffer = + iree_hal_level_zero_direct_command_buffer_cast(base_command_buffer); + + iree_hal_level_zero_device_ptr_t target_device_buffer = + iree_hal_level_zero_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_buffer)); + target_offset += iree_hal_buffer_byte_offset(target_buffer); + iree_hal_level_zero_device_ptr_t dst = + (iree_hal_level_zero_device_ptr_t)((uintptr_t)(void*) + target_device_buffer + + target_offset); + LEVEL_ZERO_RETURN_IF_ERROR( + command_buffer->context->syms, + zeCommandListAppendMemoryFill(command_buffer->command_list, dst, pattern, + pattern_length, length, NULL, 0, NULL), + "zeCommandListAppendMemoryFill"); + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_direct_command_buffer_update_buffer( + iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "need level_zero implementation"); +} + +static iree_status_t iree_hal_level_zero_direct_command_buffer_copy_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length) { + iree_hal_level_zero_direct_command_buffer_t* command_buffer = + iree_hal_level_zero_direct_command_buffer_cast(base_command_buffer); + + iree_hal_level_zero_device_ptr_t target_device_buffer = + iree_hal_level_zero_buffer_device_pointer( + 
iree_hal_buffer_allocated_buffer(target_buffer)); + target_offset += iree_hal_buffer_byte_offset(target_buffer); + iree_hal_level_zero_device_ptr_t source_device_buffer = + iree_hal_level_zero_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(source_buffer)); + source_offset += iree_hal_buffer_byte_offset(source_buffer); + iree_hal_level_zero_device_ptr_t dst = + (iree_hal_level_zero_device_ptr_t)((uintptr_t)(void*) + target_device_buffer + + target_offset); + iree_hal_level_zero_device_ptr_t src = + (iree_hal_level_zero_device_ptr_t)((uintptr_t)(void*) + source_device_buffer + + source_offset); + // TODO(raikonenfnu): Currently using NULL stream, need to figure out way to + // access proper stream from command buffer + LEVEL_ZERO_RETURN_IF_ERROR( + command_buffer->context->syms, + zeCommandListAppendMemoryCopy(command_buffer->command_list, dst, src, + length, NULL, 0, NULL), + "zeCommandListAppendMemoryCopy"); + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_direct_command_buffer_push_constants( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset, + const void* values, iree_host_size_t values_length) { + iree_hal_level_zero_direct_command_buffer_t* command_buffer = + iree_hal_level_zero_direct_command_buffer_cast(base_command_buffer); + iree_host_size_t constant_base_index = offset / sizeof(int32_t); + for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) { + command_buffer->push_constant[i + constant_base_index] = + ((uint32_t*)values)[i]; + } + return iree_ok_status(); +} + +// Tie together the binding index and its index in |bindings| array. +typedef struct { + uint32_t index; + uint32_t binding; +} iree_hal_level_zero_binding_mapping_t; + +// Helper to sort the binding based on their binding index. 
+static int compare_binding_index(const void* a, const void* b) { + const iree_hal_level_zero_binding_mapping_t* lhs = + (const iree_hal_level_zero_binding_mapping_t*)a; + const iree_hal_level_zero_binding_mapping_t* rhs = + (const iree_hal_level_zero_binding_mapping_t*)b; + return (lhs->binding > rhs->binding) - (lhs->binding < rhs->binding); +} + +static iree_status_t +iree_hal_level_zero_direct_command_buffer_push_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { + iree_hal_level_zero_direct_command_buffer_t* command_buffer = + iree_hal_level_zero_direct_command_buffer_cast(base_command_buffer); + iree_host_size_t base_binding = + iree_hal_level_zero_base_binding_index(pipeline_layout, set); + // Convention with the compiler side. We map bindings to kernel argument. + // We compact the bindings to get a dense set of arguments and keep them order + // based on the binding index. + // Sort the binding based on the binding index and map the array index to the + // argument index. + iree_hal_level_zero_binding_mapping_t + binding_used[IREE_HAL_LEVEL_ZERO_MAX_BINDING_COUNT]; + assert(binding_count < IREE_HAL_LEVEL_ZERO_MAX_BINDING_COUNT && + "binding count larger than the max expected."); + for (iree_host_size_t i = 0; i < binding_count; i++) { + iree_hal_level_zero_binding_mapping_t buffer = {i, bindings[i].binding}; + binding_used[i] = buffer; + } + qsort(binding_used, binding_count, + sizeof(iree_hal_level_zero_binding_mapping_t), compare_binding_index); + for (iree_host_size_t i = 0; i < binding_count; i++) { + iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index]; + iree_hal_level_zero_device_ptr_t device_ptr = + binding.buffer + ? 
(iree_hal_level_zero_device_ptr_t)((uintptr_t)(void*) + iree_hal_level_zero_buffer_device_pointer( + iree_hal_buffer_allocated_buffer( + binding.buffer)) + + iree_hal_buffer_byte_offset( + binding.buffer) + + binding.offset) + : 0; + *((iree_hal_level_zero_device_ptr_t*) + command_buffer->current_descriptor[i + base_binding]) = device_ptr; + } + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_direct_command_buffer_dispatch( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + iree_hal_level_zero_direct_command_buffer_t* command_buffer = + iree_hal_level_zero_direct_command_buffer_cast(base_command_buffer); + iree_hal_level_zero_direct_command_buffer_cast(base_command_buffer); + iree_hal_pipeline_layout_t* layout = + iree_hal_level_zero_executable_get_layout(executable, entry_point); + iree_host_size_t num_constants = + iree_hal_level_zero_pipeline_layout_num_constants(layout); + iree_host_size_t constant_base_index = + iree_hal_level_zero_push_constant_index(layout); + + int32_t block_size_x, block_size_y, block_size_z; + IREE_RETURN_IF_ERROR(iree_hal_level_zero_native_executable_block_size( + executable, entry_point, &block_size_x, &block_size_y, &block_size_z)); + ze_kernel_handle_t func = + iree_hal_level_zero_native_executable_for_entry_point(executable, + entry_point); + // TODO(raikonenfnu): Currently using NULL stream, need to figure out way to + // access proper stream from command buffer + LEVEL_ZERO_RETURN_IF_ERROR( + command_buffer->context->syms, + zeKernelSetGroupSize(func, block_size_x, block_size_y, block_size_z), + "zeKernelSetGroupSize"); + + // Patch the push constants in the kernel arguments. 
+ for (iree_host_size_t i = 0; i < num_constants; i++) { + *((uint32_t*)command_buffer->current_descriptor[i + constant_base_index]) = + command_buffer->push_constant[i]; + } + iree_host_size_t num_kernel_args = constant_base_index + num_constants; + for (iree_host_size_t i = 0; i < num_kernel_args; i++) { + LEVEL_ZERO_RETURN_IF_ERROR( + command_buffer->context->syms, + zeKernelSetArgumentValue(func, i, + sizeof(command_buffer->current_descriptor[i]), + command_buffer->current_descriptor[i]), + "zeKernelSetArgumentValue"); + } + + // Kernel thread-dispatch + ze_group_count_t dispatch; + dispatch.groupCountX = workgroup_x; + dispatch.groupCountY = workgroup_y; + dispatch.groupCountZ = workgroup_z; + + // Launch kernel on the GPU + LEVEL_ZERO_RETURN_IF_ERROR( + command_buffer->context->syms, + zeCommandListAppendLaunchKernel(command_buffer->command_list, func, + &dispatch, NULL, 0, NULL), + "zeCommandListAppendLaunchKernel"); + return iree_ok_status(); +} + +static iree_status_t +iree_hal_level_zero_direct_command_buffer_dispatch_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, + iree_device_size_t workgroups_offset) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "need level_zero implementation"); +} + +ze_command_list_handle_t iree_hal_level_zero_direct_command_buffer_exec( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_level_zero_direct_command_buffer_t* command_buffer = + iree_hal_level_zero_direct_command_buffer_cast(base_command_buffer); + IREE_ASSERT_TRUE(command_buffer); + return command_buffer->command_list; +} + +static iree_status_t iree_hal_rocm_direct_command_buffer_execute_commands( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_command_buffer_t* base_commands, + iree_hal_buffer_binding_table_t binding_table) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect command buffers not yet 
implemented"); +} + +static const iree_hal_command_buffer_vtable_t + iree_hal_level_zero_direct_command_buffer_vtable = { + .destroy = iree_hal_level_zero_direct_command_buffer_destroy, + .begin = iree_hal_level_zero_direct_command_buffer_begin, + .end = iree_hal_level_zero_direct_command_buffer_end, + .begin_debug_group = + iree_hal_level_zero_direct_command_buffer_begin_debug_group, + .end_debug_group = + iree_hal_level_zero_direct_command_buffer_end_debug_group, + .execution_barrier = + iree_hal_level_zero_direct_command_buffer_execution_barrier, + .signal_event = iree_hal_level_zero_direct_command_buffer_signal_event, + .reset_event = iree_hal_level_zero_direct_command_buffer_reset_event, + .wait_events = iree_hal_level_zero_direct_command_buffer_wait_events, + .discard_buffer = + iree_hal_level_zero_direct_command_buffer_discard_buffer, + .fill_buffer = iree_hal_level_zero_direct_command_buffer_fill_buffer, + .update_buffer = + iree_hal_level_zero_direct_command_buffer_update_buffer, + .copy_buffer = iree_hal_level_zero_direct_command_buffer_copy_buffer, + .push_constants = + iree_hal_level_zero_direct_command_buffer_push_constants, + .push_descriptor_set = + iree_hal_level_zero_direct_command_buffer_push_descriptor_set, + .dispatch = iree_hal_level_zero_direct_command_buffer_dispatch, + .dispatch_indirect = + iree_hal_level_zero_direct_command_buffer_dispatch_indirect, + .execute_commands = + iree_hal_rocm_direct_command_buffer_execute_commands, +}; diff --git a/experimental/level_zero/direct_command_buffer.h b/experimental/level_zero/direct_command_buffer.h new file mode 100644 index 0000000000000..1cb109c7ea589 --- /dev/null +++ b/experimental/level_zero/direct_command_buffer.h @@ -0,0 +1,56 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_DIRECT_COMMAND_BUFFER_H_ +#define IREE_HAL_LEVEL_ZERO_DIRECT_COMMAND_BUFFER_H_ + +#include "experimental/level_zero/context_wrapper.h" +#include "experimental/level_zero/dynamic_symbols.h" +#include "experimental/level_zero/level_zero_headers.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct iree_arena_block_pool_t iree_arena_block_pool_t; + +// Level Zero Kernel Information Structure +typedef struct { + ze_kernel_handle_t func; + unsigned int gridDimX; + unsigned int gridDimY; + unsigned int gridDimZ; + unsigned int blockDimX; + unsigned int blockDimY; + unsigned int blockDimZ; + void** kernelParams; +} level_zero_launch_params; + +// Creates a Level Zero direct command buffer. +iree_status_t iree_hal_level_zero_direct_command_buffer_create( + iree_hal_device_t* device, iree_hal_level_zero_context_wrapper_t* context, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, + iree_arena_block_pool_t* block_pool, ze_device_handle_t level_zero_device, + uint32_t command_queue_ordinal, + iree_hal_command_buffer_t** out_command_buffer); + +// Returns associated command_list from command buffer. +ze_command_list_handle_t iree_hal_level_zero_direct_command_buffer_exec( + iree_hal_command_buffer_t* command_buffer); + +// Returns true if |command_buffer| is a Level Zero command buffer. 
+bool iree_hal_level_zero_direct_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LEVEL_ZERO_DIRECT_COMMAND_BUFFER_H_ diff --git a/experimental/level_zero/dynamic_symbol_tables.h b/experimental/level_zero/dynamic_symbol_tables.h new file mode 100644 index 0000000000000..6503bfce43f82 --- /dev/null +++ b/experimental/level_zero/dynamic_symbol_tables.h @@ -0,0 +1,90 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +/* Context/Device/Runtime related APIs */ +ZE_PFN_DECL(zeInit, ze_init_flags_t) +ZE_PFN_DECL(zeContextCreate, ze_driver_handle_t, const ze_context_desc_t *, + ze_context_handle_t *) +ZE_PFN_DECL(zeContextDestroy, ze_context_handle_t) +ZE_PFN_DECL(zeDriverGet, uint32_t *, + ze_driver_handle_t *) // No direct, need to modify +ZE_PFN_DECL(zeDeviceGet, ze_driver_handle_t, uint32_t *, + ze_device_handle_t *) // Can get device handle and count. +ZE_PFN_DECL(zeDeviceGetProperties, ze_device_handle_t, + ze_device_properties_t *) // Can get name. 
+ +/* Command Buffer/List related APIs*/ +ZE_PFN_DECL(zeDeviceGetCommandQueueGroupProperties, ze_device_handle_t, + uint32_t *, ze_command_queue_group_properties_t *) +ZE_PFN_DECL(zeCommandQueueCreate, ze_context_handle_t, ze_device_handle_t, + const ze_command_queue_desc_t *, ze_command_queue_handle_t *) +ZE_PFN_DECL(zeCommandListCreate, ze_context_handle_t, ze_device_handle_t, + const ze_command_list_desc_t *, ze_command_list_handle_t *) +ZE_PFN_DECL(zeCommandListCreateImmediate, ze_context_handle_t, + ze_device_handle_t, const ze_command_list_desc_t *, + ze_command_list_handle_t *) +ZE_PFN_DECL(zeCommandListAppendLaunchKernel, ze_command_list_handle_t, + ze_kernel_handle_t, const ze_group_count_t *, ze_event_handle_t, + uint32_t, ze_event_handle_t *) +ZE_PFN_DECL(zeCommandListClose, ze_command_list_handle_t) +ZE_PFN_DECL(zeCommandQueueExecuteCommandLists, ze_command_queue_handle_t, + uint32_t, ze_command_list_handle_t *, ze_fence_handle_t) +ZE_PFN_DECL(zeCommandQueueSynchronize, ze_command_queue_handle_t, uint64_t) +ZE_PFN_DECL(zeCommandListDestroy, ze_command_list_handle_t) +ZE_PFN_DECL(zeCommandQueueDestroy, ze_command_queue_handle_t) + +/* Memory related APIs*/ +// NOTE: Asynchronous/Blocking set in flags. +// NOTE: Intel Shared Memory == Unified/Managed memory, where memory is shared +// between host and devices. 
+ZE_PFN_DECL(zeCommandListAppendMemoryFill, ze_command_list_handle_t, void *, + const void *, size_t, size_t, ze_event_handle_t, uint32_t, + ze_event_handle_t *) +ZE_PFN_DECL(zeCommandListAppendMemoryCopy, ze_command_list_handle_t, void *, + const void *, size_t, ze_event_handle_t, uint32_t, + ze_event_handle_t *) +ZE_PFN_DECL(zeMemAllocDevice, ze_context_handle_t, + const ze_device_mem_alloc_desc_t *, size_t, size_t, + ze_device_handle_t, void **) +ZE_PFN_DECL(zeMemAllocShared, ze_context_handle_t, + const ze_device_mem_alloc_desc_t *, + const ze_host_mem_alloc_desc_t *, size_t, size_t, + ze_device_handle_t, void **) +ZE_PFN_DECL(zeMemAllocHost, ze_context_handle_t, + const ze_host_mem_alloc_desc_t *, size_t, size_t, void **) +ZE_PFN_DECL(zeMemFree, ze_context_handle_t, void *) +ZE_PFN_DECL(zeCommandListAppendBarrier, ze_command_list_handle_t, + ze_event_handle_t, uint32_t, ze_event_handle_t *) +ZE_PFN_DECL(zeEventPoolCreate, ze_context_handle_t, + const ze_event_pool_desc_t *, uint32_t, ze_device_handle_t *, + ze_event_pool_handle_t *) +ZE_PFN_DECL(zeEventCreate, ze_event_pool_handle_t, const ze_event_desc_t *, + ze_event_handle_t *) +ZE_PFN_DECL(zeEventDestroy, ze_event_handle_t) +ZE_PFN_DECL(zeEventPoolDestroy, ze_event_pool_handle_t) +ZE_PFN_DECL(zeEventHostSynchronize, ze_event_handle_t, uint64_t) +ZE_PFN_DECL(zeCommandListAppendSignalEvent, ze_command_list_handle_t, + ze_event_handle_t) +ZE_PFN_DECL(zeCommandListAppendWaitOnEvents, ze_command_list_handle_t, uint32_t, + ze_event_handle_t *) +ZE_PFN_DECL(zeCommandListAppendEventReset, ze_command_list_handle_t, + ze_event_handle_t) + +/* Kernel generation related APIs*/ +ZE_PFN_DECL(zeModuleCreate, ze_context_handle_t, ze_device_handle_t, + const ze_module_desc_t *, ze_module_handle_t *, + ze_module_build_log_handle_t *) +ZE_PFN_DECL(zeKernelCreate, ze_module_handle_t, const ze_kernel_desc_t *, + ze_kernel_handle_t *) +ZE_PFN_DECL(zeKernelSuggestGroupSize, ze_kernel_handle_t, uint32_t, uint32_t, + uint32_t, 
uint32_t *, uint32_t *, uint32_t *) +ZE_PFN_DECL(zeKernelSetGroupSize, ze_kernel_handle_t, uint32_t, uint32_t, + uint32_t) +ZE_PFN_DECL(zeKernelSetArgumentValue, ze_kernel_handle_t, uint32_t, size_t, + const void *) +ZE_PFN_DECL(zeModuleBuildLogDestroy, ze_module_build_log_handle_t) +ZE_PFN_DECL(zeModuleBuildLogGetString, ze_module_build_log_handle_t, size_t *, + char *) \ No newline at end of file diff --git a/experimental/level_zero/dynamic_symbols.c b/experimental/level_zero/dynamic_symbols.c new file mode 100644 index 0000000000000..53a9f41b8dc4f --- /dev/null +++ b/experimental/level_zero/dynamic_symbols.c @@ -0,0 +1,65 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/level_zero/dynamic_symbols.h" + +#include + +#include "iree/base/internal/dynamic_library.h" +#include "iree/base/target_platform.h" + +static const char* kLevelZeroLoaderSearchNames[] = { +#if defined(IREE_PLATFORM_WINDOWS) + NULL, +#else + "libze_loader.so", +#endif +}; + +static iree_status_t iree_hal_level_zero_dynamic_symbols_resolve_all( + iree_hal_level_zero_dynamic_symbols_t* syms) { +#define ZE_PFN_DECL(levelZeroSymbolName, ...) 
\ + { \ + static const char* kName = #levelZeroSymbolName; \ + IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol( \ + syms->loader_library, kName, (void**)&syms->levelZeroSymbolName)); \ + } +#include "experimental/level_zero/dynamic_symbol_tables.h" // IWYU pragma: keep +#undef ZE_PFN_DECL + return iree_ok_status(); +} + +iree_status_t iree_hal_level_zero_dynamic_symbols_initialize( + iree_allocator_t allocator, + iree_hal_level_zero_dynamic_symbols_t* out_syms) { + IREE_TRACE_ZONE_BEGIN(z0); + memset(out_syms, 0, sizeof(*out_syms)); + iree_status_t status = iree_dynamic_library_load_from_files( + IREE_ARRAYSIZE(kLevelZeroLoaderSearchNames), kLevelZeroLoaderSearchNames, + IREE_DYNAMIC_LIBRARY_FLAG_NONE, allocator, &out_syms->loader_library); + if (iree_status_is_not_found(status)) { + iree_status_ignore(status); + status = iree_make_status(IREE_STATUS_UNAVAILABLE, + "LevelZero runtime library not available; ensure " + "installed and on path"); + } + if (iree_status_is_ok(status)) { + status = iree_hal_level_zero_dynamic_symbols_resolve_all(out_syms); + } + if (!iree_status_is_ok(status)) { + iree_hal_level_zero_dynamic_symbols_deinitialize(out_syms); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +void iree_hal_level_zero_dynamic_symbols_deinitialize( + iree_hal_level_zero_dynamic_symbols_t* syms) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_dynamic_library_release(syms->loader_library); + memset(syms, 0, sizeof(*syms)); + IREE_TRACE_ZONE_END(z0); +} diff --git a/experimental/level_zero/dynamic_symbols.h b/experimental/level_zero/dynamic_symbols.h new file mode 100644 index 0000000000000..c1cca07fc6de9 --- /dev/null +++ b/experimental/level_zero/dynamic_symbols.h @@ -0,0 +1,48 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_DYNAMIC_SYMBOLS_H_ +#define IREE_HAL_LEVEL_ZERO_DYNAMIC_SYMBOLS_H_ + +#include "experimental/level_zero/level_zero_headers.h" +#include "iree/base/api.h" +#include "iree/base/internal/dynamic_library.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// DynamicSymbols allows dynamically loading a subset of the LEVEL_ZERO driver +// API. It loads all the functions declared in `dynamic_symbol_tables.h` and +// fails if any of the symbols is not available. The function signatures match +// the declarations in `ze_api.h`. +typedef struct iree_hal_level_zero_dynamic_symbols_t { + iree_dynamic_library_t* loader_library; + +#define ZE_PFN_DECL(levelZeroSymbolName, ...) \ + ze_result_t (*levelZeroSymbolName)(__VA_ARGS__); +#include "experimental/level_zero/dynamic_symbol_tables.h" // IWYU pragma: export +#undef ZE_PFN_DECL +} iree_hal_level_zero_dynamic_symbols_t; + +// Initializes |out_syms| in-place with dynamically loaded LEVEL_ZERO symbols. +// iree_hal_level_zero_dynamic_symbols_deinitialize must be used to release the +// library resources. +iree_status_t iree_hal_level_zero_dynamic_symbols_initialize( + iree_allocator_t allocator, + iree_hal_level_zero_dynamic_symbols_t* out_syms); + +// Deinitializes |syms| by unloading the backing library. All function pointers +// will be invalidated. They _may_ still work if there are other reasons the +// library remains loaded so be careful. 
+void iree_hal_level_zero_dynamic_symbols_deinitialize( + iree_hal_level_zero_dynamic_symbols_t* syms); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LEVEL_ZERO_DYNAMIC_SYMBOLS_H_ diff --git a/experimental/level_zero/dynamic_symbols_test.cc b/experimental/level_zero/dynamic_symbols_test.cc new file mode 100644 index 0000000000000..eaf9cd5bdd424 --- /dev/null +++ b/experimental/level_zero/dynamic_symbols_test.cc @@ -0,0 +1,85 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/level_zero/dynamic_symbols.h" + +#include + +#include "iree/base/api.h" +#include "iree/testing/gtest.h" + +namespace iree { +namespace hal { +namespace level_zero { +namespace { + +#define LEVEL_ZERO_CHECK_ERRORS(expr) \ + { \ + ze_result_t status = expr; \ + ASSERT_EQ(ZE_RESULT_SUCCESS, status); \ + } + +TEST(DynamicSymbolsTest, CreateFromSystemLoader) { + iree_hal_level_zero_dynamic_symbols_t symbols; + iree_status_t status = iree_hal_level_zero_dynamic_symbols_initialize( + iree_allocator_system(), &symbols); + if (!iree_status_is_ok(status)) { + std::cerr << "Symbols cannot be loaded, skipping test."; + GTEST_SKIP(); + } + + LEVEL_ZERO_CHECK_ERRORS(symbols.zeInit(0)); + // Get the driver + uint32_t driverCount = 0; + LEVEL_ZERO_CHECK_ERRORS(symbols.zeDriverGet(&driverCount, nullptr)); + ze_driver_handle_t driverHandle; + if (driverCount > 0) { + LEVEL_ZERO_CHECK_ERRORS(symbols.zeDriverGet(&driverCount, &driverHandle)); + } else { + std::cerr << "Cannot find Intel Level Zero driver, skipping test."; + GTEST_SKIP(); + } + + // Create the context + ze_context_desc_t contextDescription = {}; + contextDescription.stype = ZE_STRUCTURE_TYPE_CONTEXT_DESC; + ze_context_handle_t context; + LEVEL_ZERO_CHECK_ERRORS( + symbols.zeContextCreate(driverHandle, 
&contextDescription, &context)); + + // Get the device + uint32_t deviceCount = 0; + LEVEL_ZERO_CHECK_ERRORS( + symbols.zeDeviceGet(driverHandle, &deviceCount, nullptr)); + + ze_device_handle_t device; + if (deviceCount > 0) { + LEVEL_ZERO_CHECK_ERRORS( + symbols.zeDeviceGet(driverHandle, &deviceCount, &device)); + } else { + std::cerr << "Cannot find Intel Level Zero device, skipping test."; + GTEST_SKIP(); + } + + // Print basic properties of the device + ze_device_properties_t deviceProperties; + deviceProperties.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + LEVEL_ZERO_CHECK_ERRORS( + symbols.zeDeviceGetProperties(device, &deviceProperties)); + std::cout << "Device : " << deviceProperties.name << "\n" + << "Type : " + << ((deviceProperties.type == ZE_DEVICE_TYPE_GPU) ? "GPU" : "FPGA") + << "\n" + << "Vendor ID: " << std::hex << deviceProperties.vendorId + << std::dec << "\n"; + + iree_hal_level_zero_dynamic_symbols_deinitialize(&symbols); +} + +} // namespace +} // namespace level_zero +} // namespace hal +} // namespace iree diff --git a/experimental/level_zero/event_semaphore.c b/experimental/level_zero/event_semaphore.c new file mode 100644 index 0000000000000..690434456837c --- /dev/null +++ b/experimental/level_zero/event_semaphore.c @@ -0,0 +1,114 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/level_zero/event_semaphore.h" + +#include + +#include "iree/base/api.h" +#include "iree/hal/utils/semaphore_base.h" + +typedef struct iree_hal_level_zero_semaphore_t { + iree_hal_semaphore_t base; + iree_hal_level_zero_context_wrapper_t* context; + iree_atomic_int64_t value; +} iree_hal_level_zero_semaphore_t; + +static const iree_hal_semaphore_vtable_t iree_hal_level_zero_semaphore_vtable; + +static iree_hal_level_zero_semaphore_t* iree_hal_level_zero_semaphore_cast( + iree_hal_semaphore_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_level_zero_semaphore_vtable); + return (iree_hal_level_zero_semaphore_t*)base_value; +} + +iree_status_t iree_hal_level_zero_semaphore_create( + iree_hal_level_zero_context_wrapper_t* context, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) { + IREE_ASSERT_ARGUMENT(context); + IREE_ASSERT_ARGUMENT(out_semaphore); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_level_zero_semaphore_t* semaphore = NULL; + iree_status_t status = iree_allocator_malloc( + context->host_allocator, sizeof(*semaphore), (void**)&semaphore); + if (iree_status_is_ok(status)) { + iree_hal_semaphore_initialize(&iree_hal_level_zero_semaphore_vtable, + &semaphore->base); + iree_atomic_store_int64(&semaphore->value, initial_value, + iree_memory_order_release); + semaphore->context = context; + *out_semaphore = &semaphore->base; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_level_zero_semaphore_destroy( + iree_hal_semaphore_t* base_semaphore) { + iree_hal_level_zero_semaphore_t* semaphore = + iree_hal_level_zero_semaphore_cast(base_semaphore); + iree_allocator_t host_allocator = semaphore->context->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_semaphore_deinitialize(&semaphore->base); + iree_allocator_free(host_allocator, semaphore); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t 
iree_hal_level_zero_semaphore_query( + iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { + // TODO: Support semaphores completely. + iree_hal_level_zero_semaphore_t* semaphore = + iree_hal_level_zero_semaphore_cast(base_semaphore); + // TODO: Support semaphores completely. + *out_value = + iree_atomic_load_int64(&semaphore->value, iree_memory_order_acquire); + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_semaphore_signal( + iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { + iree_hal_level_zero_semaphore_t* semaphore = + iree_hal_level_zero_semaphore_cast(base_semaphore); + // TODO: Support semaphores completely. Return OK currently as everything is + // synchronized for each submit to allow things to run. + iree_atomic_store_int64(&semaphore->value, new_value, + iree_memory_order_release); + iree_hal_semaphore_poll(&semaphore->base); + return iree_ok_status(); +} + +static void iree_hal_level_zero_semaphore_fail( + iree_hal_semaphore_t* base_semaphore, iree_status_t status) { + iree_hal_level_zero_semaphore_t* semaphore = + iree_hal_level_zero_semaphore_cast(base_semaphore); + // TODO: save status and mark timepoint as failed. + iree_status_ignore(status); + iree_hal_semaphore_poll(&semaphore->base); +} + +static iree_status_t iree_hal_level_zero_semaphore_wait( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + iree_timeout_t timeout) { + iree_hal_level_zero_semaphore_t* semaphore = + iree_hal_level_zero_semaphore_cast(base_semaphore); + // TODO: Support semaphores completely. Return OK currently as everything is + // synchronized for each submit to allow things to run. 
+ iree_hal_semaphore_poll(&semaphore->base); + return iree_ok_status(); +} + +static const iree_hal_semaphore_vtable_t iree_hal_level_zero_semaphore_vtable = + { + .destroy = iree_hal_level_zero_semaphore_destroy, + .query = iree_hal_level_zero_semaphore_query, + .signal = iree_hal_level_zero_semaphore_signal, + .fail = iree_hal_level_zero_semaphore_fail, + .wait = iree_hal_level_zero_semaphore_wait, +}; diff --git a/experimental/level_zero/event_semaphore.h b/experimental/level_zero/event_semaphore.h new file mode 100644 index 0000000000000..3b3f63ecfd49d --- /dev/null +++ b/experimental/level_zero/event_semaphore.h @@ -0,0 +1,30 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_SEMAPHORE_H_ +#define IREE_HAL_LEVEL_ZERO_SEMAPHORE_H_ + +#include + +#include "experimental/level_zero/context_wrapper.h" +#include "experimental/level_zero/status_util.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a Level Zero semaphore initialized to |initial_value|. +iree_status_t iree_hal_level_zero_semaphore_create( + iree_hal_level_zero_context_wrapper_t* context, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LEVEL_ZERO_SEMAPHORE_H_ diff --git a/experimental/level_zero/level_zero_allocator.c b/experimental/level_zero/level_zero_allocator.c new file mode 100644 index 0000000000000..04d64119b51c3 --- /dev/null +++ b/experimental/level_zero/level_zero_allocator.c @@ -0,0 +1,349 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/level_zero/level_zero_allocator.h" + +#include + +#include "experimental/level_zero/dynamic_symbols.h" +#include "experimental/level_zero/level_zero_buffer.h" +#include "experimental/level_zero/status_util.h" +#include "iree/base/api.h" + +typedef struct iree_hal_level_zero_allocator_t { + iree_hal_resource_t resource; + iree_hal_device_t* base_device; + ze_device_handle_t level_zero_device; + iree_hal_level_zero_context_wrapper_t* context; + + IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;) +} iree_hal_level_zero_allocator_t; + +static const iree_hal_allocator_vtable_t iree_hal_level_zero_allocator_vtable; + +static iree_hal_level_zero_allocator_t* iree_hal_level_zero_allocator_cast( + iree_hal_allocator_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_level_zero_allocator_vtable); + return (iree_hal_level_zero_allocator_t*)base_value; +} + +iree_status_t iree_hal_level_zero_allocator_create( + iree_hal_device_t* base_device, ze_device_handle_t level_zero_device, + iree_hal_level_zero_context_wrapper_t* context, + iree_hal_allocator_t** out_allocator) { + IREE_ASSERT_ARGUMENT(base_device); + IREE_ASSERT_ARGUMENT(context); + IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_level_zero_allocator_t* allocator = NULL; + iree_status_t status = iree_allocator_malloc( + context->host_allocator, sizeof(*allocator), (void**)&allocator); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_level_zero_allocator_vtable, + &allocator->resource); + allocator->context = context; + allocator->base_device = base_device; + allocator->level_zero_device = level_zero_device; + *out_allocator = (iree_hal_allocator_t*)allocator; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_level_zero_allocator_destroy( + iree_hal_allocator_t* IREE_RESTRICT base_allocator) { + iree_hal_level_zero_allocator_t* allocator = + 
iree_hal_level_zero_allocator_cast(base_allocator); + iree_allocator_t host_allocator = allocator->context->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, allocator); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_allocator_t iree_hal_level_zero_allocator_host_allocator( + const iree_hal_allocator_t* IREE_RESTRICT base_allocator) { + iree_hal_level_zero_allocator_t* allocator = + (iree_hal_level_zero_allocator_t*)base_allocator; + return allocator->context->host_allocator; +} + +static iree_status_t iree_hal_level_zero_allocator_trim( + iree_hal_allocator_t* IREE_RESTRICT base_allocator) { + return iree_ok_status(); +} + +static void iree_hal_level_zero_allocator_query_statistics( + iree_hal_allocator_t* IREE_RESTRICT base_allocator, + iree_hal_allocator_statistics_t* IREE_RESTRICT out_statistics) { + IREE_STATISTICS({ + iree_hal_level_zero_allocator_t* allocator = + iree_hal_level_zero_allocator_cast(base_allocator); + memcpy(out_statistics, &allocator->statistics, sizeof(*out_statistics)); + }); +} + +static iree_status_t iree_hal_level_zero_allocator_query_memory_heaps( + iree_hal_allocator_t* IREE_RESTRICT base_allocator, + iree_host_size_t capacity, + iree_hal_allocator_memory_heap_t* IREE_RESTRICT heaps, + iree_host_size_t* IREE_RESTRICT out_count) { + iree_host_size_t count = 3; + if (out_count) *out_count = count; + if (capacity < count) { + return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); + } + + // Don't think there's a query for these. + // Max allocation size may be much smaller in certain memory types such as + // page-locked memory and it'd be good to enforce that. 
+ const iree_device_size_t max_allocation_size = ~(iree_device_size_t)0; + const iree_device_size_t min_alignment = 64; + + // Device-local memory (dispatch resources): + heaps[0] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + .allowed_usage = + IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_DISPATCH, + .max_allocation_size = max_allocation_size, + .min_alignment = min_alignment, + }; + + // Write-combined page-locked host-local memory (upload): + heaps[1] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_COHERENT, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, + .max_allocation_size = max_allocation_size, + .min_alignment = min_alignment, + }; + + // Cached page-locked host-local memory (download): + heaps[2] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_COHERENT | + IREE_HAL_MEMORY_TYPE_HOST_CACHED, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, + .max_allocation_size = max_allocation_size, + .min_alignment = min_alignment, + }; + + return iree_ok_status(); +} + +static iree_hal_buffer_compatibility_t +iree_hal_level_zero_allocator_query_buffer_compatibility( + iree_hal_allocator_t* IREE_RESTRICT base_allocator, + iree_hal_buffer_params_t* IREE_RESTRICT params, + iree_device_size_t* IREE_RESTRICT allocation_size) { + // All buffers can be allocated on the heap. + iree_hal_buffer_compatibility_t compatibility = + IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE; + + // LevelZero supports host <-> device for all copies. 
+ if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; + } + + // Buffers can only be used on the queue if they are device visible. + if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + if (iree_any_bit_set(params->usage, + IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH; + } + } + + // We are now optimal. + params->type &= ~IREE_HAL_MEMORY_TYPE_OPTIMAL; + + // Guard against the corner case where the requested buffer size is 0. The + // application is unlikely to do anything when requesting a 0-byte buffer; but + // it can happen in real world use cases. So we should at least not crash. + if (*allocation_size == 0) *allocation_size = 4; + + return compatibility; +} + +static void iree_hal_level_zero_buffer_free( + iree_hal_level_zero_context_wrapper_t* context, + iree_hal_memory_type_t memory_type, + iree_hal_level_zero_device_ptr_t device_ptr, void* host_ptr) { + if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) { + // Device local. + LEVEL_ZERO_IGNORE_ERROR(context->syms, + zeMemFree(context->level_zero_context, device_ptr)); + } else { + // Host local. + LEVEL_ZERO_IGNORE_ERROR(context->syms, + zeMemFree(context->level_zero_context, host_ptr)); + } +} + +static iree_status_t iree_hal_level_zero_allocator_allocate_buffer( + iree_hal_allocator_t* IREE_RESTRICT base_allocator, + const iree_hal_buffer_params_t* IREE_RESTRICT params, + iree_device_size_t allocation_size, iree_const_byte_span_t initial_data, + iree_hal_buffer_t** IREE_RESTRICT out_buffer) { + iree_hal_level_zero_allocator_t* allocator = + iree_hal_level_zero_allocator_cast(base_allocator); + + // Coerce options into those required by the current device. 
+ iree_hal_buffer_params_t compat_params = *params; + if (!iree_all_bits_set( + iree_hal_level_zero_allocator_query_buffer_compatibility( + base_allocator, &compat_params, &allocation_size), + IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE)) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "allocator cannot allocate a buffer with the given parameters"); + } + + size_t alloc_alignment = 32; + + iree_status_t status = iree_ok_status(); + // Defining device memory alloc. + ze_device_mem_alloc_desc_t memAllocDesc = { + ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC}; + memAllocDesc.flags = ZE_DEVICE_MEM_ALLOC_FLAG_BIAS_CACHED; + memAllocDesc.ordinal = 0; + // Defining host memory alloc. + ze_host_mem_alloc_desc_t hostDesc = {ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC}; + + // Defining memalloc limits. + ze_relaxed_allocation_limits_exp_desc_t exceedCapacity = { + ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC, NULL, + ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE}; + hostDesc.pNext = &exceedCapacity; + memAllocDesc.pNext = &exceedCapacity; + + void* host_ptr = NULL; + iree_hal_level_zero_device_ptr_t device_ptr = NULL; + if (iree_all_bits_set(compat_params.type, + IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) { + // Device local case. + if (iree_all_bits_set(compat_params.type, + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { + status = LEVEL_ZERO_RESULT_TO_STATUS( + allocator->context->syms, + zeMemAllocShared(allocator->context->level_zero_context, + &memAllocDesc, &hostDesc, allocation_size, + alloc_alignment, allocator->level_zero_device, + (void**)&device_ptr)); + host_ptr = (void*)device_ptr; + } else { + // Device only. 
+ status = LEVEL_ZERO_RESULT_TO_STATUS( + allocator->context->syms, + zeMemAllocDevice(allocator->context->level_zero_context, + &memAllocDesc, allocation_size, alloc_alignment, + allocator->level_zero_device, (void**)&device_ptr)); + } + } else { + // Since in Level Zero host memory is visible to device, we can simply + // allocate on host and set device_ptr to point to same data. + status = LEVEL_ZERO_RESULT_TO_STATUS( + allocator->context->syms, + zeMemAllocHost(allocator->context->level_zero_context, &hostDesc, + allocation_size, 64, &host_ptr)); + device_ptr = (iree_hal_level_zero_device_ptr_t)host_ptr; + } + + iree_hal_buffer_t* buffer = NULL; + if (iree_status_is_ok(status)) { + status = iree_hal_level_zero_buffer_wrap( + (iree_hal_allocator_t*)allocator, compat_params.type, + compat_params.access, compat_params.usage, allocation_size, + /*byte_offset=*/0, + /*byte_length=*/allocation_size, device_ptr, host_ptr, &buffer); + } + + // Copy the initial contents into the buffer. This may require staging. 
+ if (iree_status_is_ok(status) && + !iree_const_byte_span_is_empty(initial_data)) { + status = iree_hal_device_transfer_range( + allocator->base_device, + iree_hal_make_host_transfer_buffer_span((void*)initial_data.data, + initial_data.data_length), + 0, iree_hal_make_device_transfer_buffer(buffer), 0, + initial_data.data_length, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, + iree_infinite_timeout()); + } + + if (iree_status_is_ok(status)) { + IREE_STATISTICS(iree_hal_allocator_statistics_record_alloc( + &allocator->statistics, compat_params.type, allocation_size)); + *out_buffer = buffer; + } else { + if (!buffer) { + iree_hal_level_zero_buffer_free(allocator->context, compat_params.type, + device_ptr, host_ptr); + } else { + iree_hal_buffer_release(buffer); + } + } + return status; +} + +static void iree_hal_level_zero_allocator_deallocate_buffer( + iree_hal_allocator_t* IREE_RESTRICT base_allocator, + iree_hal_buffer_t* IREE_RESTRICT base_buffer) { + iree_hal_level_zero_allocator_t* allocator = + iree_hal_level_zero_allocator_cast(base_allocator); + + iree_hal_memory_type_t memory_type = iree_hal_buffer_memory_type(base_buffer); + iree_hal_level_zero_buffer_free( + allocator->context, memory_type, + iree_hal_level_zero_buffer_device_pointer(base_buffer), + iree_hal_level_zero_buffer_host_pointer(base_buffer)); + + IREE_STATISTICS(iree_hal_allocator_statistics_record_free( + &allocator->statistics, memory_type, + iree_hal_buffer_allocation_size(base_buffer))); + + iree_hal_buffer_destroy(base_buffer); +} + +static iree_status_t iree_hal_level_zero_allocator_import_buffer( + iree_hal_allocator_t* IREE_RESTRICT base_allocator, + const iree_hal_buffer_params_t* IREE_RESTRICT params, + iree_hal_external_buffer_t* IREE_RESTRICT external_buffer, + iree_hal_buffer_release_callback_t release_callback, + iree_hal_buffer_t** IREE_RESTRICT out_buffer) { + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "importing from external buffers not supported"); +} + +static 
iree_status_t iree_hal_level_zero_allocator_export_buffer( + iree_hal_allocator_t* IREE_RESTRICT base_allocator, + iree_hal_buffer_t* IREE_RESTRICT buffer, + iree_hal_external_buffer_type_t requested_type, + iree_hal_external_buffer_flags_t requested_flags, + iree_hal_external_buffer_t* IREE_RESTRICT out_external_buffer) { + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "exporting to external buffers not supported"); +} + +static const iree_hal_allocator_vtable_t iree_hal_level_zero_allocator_vtable = + { + .destroy = iree_hal_level_zero_allocator_destroy, + .host_allocator = iree_hal_level_zero_allocator_host_allocator, + .trim = iree_hal_level_zero_allocator_trim, + .query_statistics = iree_hal_level_zero_allocator_query_statistics, + .query_memory_heaps = iree_hal_level_zero_allocator_query_memory_heaps, + .query_buffer_compatibility = + iree_hal_level_zero_allocator_query_buffer_compatibility, + .allocate_buffer = iree_hal_level_zero_allocator_allocate_buffer, + .deallocate_buffer = iree_hal_level_zero_allocator_deallocate_buffer, + .import_buffer = iree_hal_level_zero_allocator_import_buffer, + .export_buffer = iree_hal_level_zero_allocator_export_buffer, +}; diff --git a/experimental/level_zero/level_zero_allocator.h b/experimental/level_zero/level_zero_allocator.h new file mode 100644 index 0000000000000..285e9e99343c9 --- /dev/null +++ b/experimental/level_zero/level_zero_allocator.h @@ -0,0 +1,29 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_LEVEL_ZERO_ALLOCATOR_H_
+#define IREE_HAL_LEVEL_ZERO_ALLOCATOR_H_
+
+#include "experimental/level_zero/context_wrapper.h"
+#include "experimental/level_zero/status_util.h"
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Creates a Level Zero allocator that allocates buffers on
+// |level_zero_device| using the Level Zero context held by |context|.
+// |base_device| is the HAL device used to upload initial buffer contents.
+// The allocator itself is allocated from |context->host_allocator| and is
+// returned in |out_allocator| on success.
+iree_status_t iree_hal_level_zero_allocator_create(
+    iree_hal_device_t* base_device, ze_device_handle_t level_zero_device,
+    iree_hal_level_zero_context_wrapper_t* context,
+    iree_hal_allocator_t** out_allocator);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_LEVEL_ZERO_ALLOCATOR_H_
diff --git a/experimental/level_zero/level_zero_buffer.c b/experimental/level_zero/level_zero_buffer.c
new file mode 100644
index 0000000000000..275ad01aa35b3
--- /dev/null
+++ b/experimental/level_zero/level_zero_buffer.c
@@ -0,0 +1,140 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/level_zero/level_zero_buffer.h" + +#include +#include +#include + +#include "iree/base/api.h" + +typedef struct iree_hal_level_zero_buffer_t { + iree_hal_buffer_t base; + void* host_ptr; + iree_hal_level_zero_device_ptr_t device_ptr; +} iree_hal_level_zero_buffer_t; + +static const iree_hal_buffer_vtable_t iree_hal_level_zero_buffer_vtable; + +static iree_hal_level_zero_buffer_t* iree_hal_level_zero_buffer_cast( + iree_hal_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_level_zero_buffer_vtable); + return (iree_hal_level_zero_buffer_t*)base_value; +} + +iree_status_t iree_hal_level_zero_buffer_wrap( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size, + iree_device_size_t byte_offset, iree_device_size_t byte_length, + iree_hal_level_zero_device_ptr_t device_ptr, void* host_ptr, + iree_hal_buffer_t** out_buffer) { + IREE_ASSERT_ARGUMENT(allocator); + IREE_ASSERT_ARGUMENT(out_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_t host_allocator = + iree_hal_allocator_host_allocator(allocator); + iree_hal_level_zero_buffer_t* buffer = NULL; + iree_status_t status = + iree_allocator_malloc(host_allocator, sizeof(*buffer), (void**)&buffer); + if (iree_status_is_ok(status)) { + iree_hal_buffer_initialize( + host_allocator, allocator, &buffer->base, allocation_size, byte_offset, + byte_length, memory_type, allowed_access, allowed_usage, + &iree_hal_level_zero_buffer_vtable, &buffer->base); + buffer->host_ptr = host_ptr; + buffer->device_ptr = device_ptr; + *out_buffer = &buffer->base; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_level_zero_buffer_destroy(iree_hal_buffer_t* base_buffer) { + iree_hal_level_zero_buffer_t* buffer = + iree_hal_level_zero_buffer_cast(base_buffer); + iree_allocator_t 
host_allocator = base_buffer->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + iree_allocator_free(host_allocator, buffer); + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_level_zero_buffer_map_range( + iree_hal_buffer_t* base_buffer, iree_hal_mapping_mode_t mapping_mode, + iree_hal_memory_access_t memory_access, + iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length, + iree_hal_buffer_mapping_t* mapping) { + iree_hal_level_zero_buffer_t* buffer = + iree_hal_level_zero_buffer_cast(base_buffer); + + // TODO(benvanik): add upload/download for unmapped buffers. + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type( + iree_hal_buffer_memory_type(base_buffer), + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_validate_usage(iree_hal_buffer_allowed_usage(base_buffer), + IREE_HAL_BUFFER_USAGE_MAPPING)); + + uint8_t* data_ptr = (uint8_t*)(buffer->host_ptr) + local_byte_offset; + // If we mapped for discard scribble over the bytes. This is not a mandated + // behavior but it will make debugging issues easier. Alternatively for + // heap buffers we could reallocate them such that ASAN yells, but that + // would only work if the entire buffer was discarded. +#ifndef NDEBUG + if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) { + memset(data_ptr, 0xCD, local_byte_length); + } +#endif // !NDEBUG + + mapping->contents = iree_make_byte_span(data_ptr, local_byte_length); + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_buffer_unmap_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length, iree_hal_buffer_mapping_t* mapping) { + // Nothing to do (today). + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_buffer_invalidate_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length) { + // Nothing to do. 
+ return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_buffer_flush_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length) { + // Nothing to do. + return iree_ok_status(); +} + +iree_hal_level_zero_device_ptr_t iree_hal_level_zero_buffer_device_pointer( + iree_hal_buffer_t* base_buffer) { + iree_hal_level_zero_buffer_t* buffer = + iree_hal_level_zero_buffer_cast(base_buffer); + return buffer->device_ptr; +} + +void* iree_hal_level_zero_buffer_host_pointer(iree_hal_buffer_t* base_buffer) { + iree_hal_level_zero_buffer_t* buffer = + iree_hal_level_zero_buffer_cast(base_buffer); + return buffer->host_ptr; +} + +static const iree_hal_buffer_vtable_t iree_hal_level_zero_buffer_vtable = { + .recycle = iree_hal_buffer_recycle, + .destroy = iree_hal_level_zero_buffer_destroy, + .map_range = iree_hal_level_zero_buffer_map_range, + .unmap_range = iree_hal_level_zero_buffer_unmap_range, + .invalidate_range = iree_hal_level_zero_buffer_invalidate_range, + .flush_range = iree_hal_level_zero_buffer_flush_range, +}; diff --git a/experimental/level_zero/level_zero_buffer.h b/experimental/level_zero/level_zero_buffer.h new file mode 100644 index 0000000000000..375ae6dbe694a --- /dev/null +++ b/experimental/level_zero/level_zero_buffer.h @@ -0,0 +1,42 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_BUFFER_H_ +#define IREE_HAL_LEVEL_ZERO_BUFFER_H_ + +#include "experimental/level_zero/level_zero_headers.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef void* iree_hal_level_zero_device_ptr_t; + +// Wraps a Level Zero allocation in an iree_hal_buffer_t. 
+iree_status_t iree_hal_level_zero_buffer_wrap( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size, + iree_device_size_t byte_offset, iree_device_size_t byte_length, + iree_hal_level_zero_device_ptr_t device_ptr, void* host_ptr, + iree_hal_buffer_t** out_buffer); + +// Returns the Level Zero base pointer for the given |buffer|. +// This is the entire allocated_buffer and must be offset by the buffer +// byte_offset and byte_length when used. +iree_hal_level_zero_device_ptr_t iree_hal_level_zero_buffer_device_pointer( + iree_hal_buffer_t* buffer); + +// Returns the Level Zero host pointer for the given |buffer|, if available. +void* iree_hal_level_zero_buffer_host_pointer(iree_hal_buffer_t* buffer); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LEVEL_ZERO_BUFFER_H_ diff --git a/experimental/level_zero/level_zero_device.c b/experimental/level_zero/level_zero_device.c new file mode 100644 index 0000000000000..d27abd164e69a --- /dev/null +++ b/experimental/level_zero/level_zero_device.c @@ -0,0 +1,427 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/level_zero/level_zero_device.h" + +#include +#include +#include + +#include "experimental/level_zero/context_wrapper.h" +#include "experimental/level_zero/direct_command_buffer.h" +#include "experimental/level_zero/dynamic_symbols.h" +#include "experimental/level_zero/event_semaphore.h" +#include "experimental/level_zero/level_zero_allocator.h" +#include "experimental/level_zero/level_zero_event.h" +#include "experimental/level_zero/nop_executable_cache.h" +#include "experimental/level_zero/pipeline_layout.h" +#include "experimental/level_zero/status_util.h" +#include "iree/base/internal/arena.h" +#include "iree/hal/utils/buffer_transfer.h" + +//===----------------------------------------------------------------------===// +// iree_hal_level_zero_device_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_level_zero_device_t { + iree_hal_resource_t resource; + iree_string_view_t identifier; + + // Block pool used for command buffers with a larger block size (as command + // buffers can contain inlined data uploads). + iree_arena_block_pool_t block_pool; + + // Optional driver that owns the Level Zero symbols. We retain it for our + // lifetime to ensure the symbols remains valid. + iree_hal_driver_t* driver; + + // Level Zero APIs. 
+ ze_device_handle_t device; + uint32_t command_queue_ordinal; + ze_command_queue_handle_t command_queue; + ze_event_pool_handle_t event_pool; + + iree_hal_level_zero_context_wrapper_t context_wrapper; + iree_hal_allocator_t* device_allocator; + iree_hal_event_t* event; + +} iree_hal_level_zero_device_t; + +static const iree_hal_device_vtable_t iree_hal_level_zero_device_vtable; + +static iree_hal_level_zero_device_t* iree_hal_level_zero_device_cast( + iree_hal_device_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_level_zero_device_vtable); + return (iree_hal_level_zero_device_t*)base_value; +} + +static void iree_hal_level_zero_device_destroy(iree_hal_device_t* base_device) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); + IREE_TRACE_ZONE_BEGIN(z0); + + // There should be no more buffers live that use the allocator. + iree_hal_allocator_release(device->device_allocator); + LEVEL_ZERO_IGNORE_ERROR(device->context_wrapper.syms, + zeCommandQueueDestroy(device->command_queue)); + // Finally, destroy the device. 
+ iree_hal_driver_release(device->driver); + + iree_allocator_free(host_allocator, device); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_level_zero_device_create_internal( + iree_hal_driver_t* driver, iree_string_view_t identifier, + ze_device_handle_t level_zero_device, uint32_t command_queue_ordinal, + ze_command_queue_handle_t command_queue, ze_event_pool_handle_t event_pool, + ze_context_handle_t level_zero_context, + iree_hal_level_zero_dynamic_symbols_t* syms, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + iree_hal_level_zero_device_t* device = NULL; + iree_host_size_t total_size = sizeof(*device) + identifier.size; + IREE_RETURN_IF_ERROR( + iree_allocator_malloc(host_allocator, total_size, (void**)&device)); + memset(device, 0, total_size); + iree_hal_resource_initialize(&iree_hal_level_zero_device_vtable, + &device->resource); + device->driver = driver; + iree_hal_driver_retain(device->driver); + uint8_t* buffer_ptr = (uint8_t*)device + sizeof(*device); + buffer_ptr += iree_string_view_append_to_buffer( + identifier, &device->identifier, (char*)buffer_ptr); + device->device = level_zero_device; + device->command_queue_ordinal = command_queue_ordinal; + device->command_queue = command_queue; + device->event_pool = event_pool; + device->context_wrapper.level_zero_context = level_zero_context; + device->context_wrapper.host_allocator = host_allocator; + device->context_wrapper.syms = syms; + iree_status_t status = iree_hal_level_zero_allocator_create( + (iree_hal_device_t*)device, device->device, &device->context_wrapper, + &device->device_allocator); + if (iree_status_is_ok(status)) { + *out_device = (iree_hal_device_t*)device; + } else { + iree_hal_device_release((iree_hal_device_t*)device); + } + return status; +} + +iree_status_t iree_hal_level_zero_device_create( + iree_hal_driver_t* driver, iree_string_view_t identifier, + iree_hal_level_zero_dynamic_symbols_t* syms, + ze_device_handle_t level_zero_device, + 
ze_context_handle_t level_zero_context, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + IREE_TRACE_ZONE_BEGIN(z0); + // Create a command queue + uint32_t num_queue_groups = 0; + iree_status_t status = LEVEL_ZERO_RESULT_TO_STATUS( + syms, + zeDeviceGetCommandQueueGroupProperties(level_zero_device, + &num_queue_groups, NULL), + "zeDeviceGetCommandQueueGroupProperties"); + if (num_queue_groups == 0) { + return iree_make_status(IREE_STATUS_NOT_FOUND, "No queue groups found"); + } + ze_command_queue_group_properties_t* queue_properties = + (ze_command_queue_group_properties_t*)malloc( + num_queue_groups * sizeof(ze_command_queue_group_properties_t)); + for (uint32_t i = 0; i < num_queue_groups; ++i) { + queue_properties[i].stype = + ZE_STRUCTURE_TYPE_COMMAND_QUEUE_GROUP_PROPERTIES; + } + status = LEVEL_ZERO_RESULT_TO_STATUS( + syms, + zeDeviceGetCommandQueueGroupProperties( + level_zero_device, &num_queue_groups, queue_properties), + "zeDeviceGetCommandQueueGroupProperties"); + + ze_command_queue_desc_t command_queue_desc = { + .stype = ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC, + .index = 0, + .mode = ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS}; + for (uint32_t i = 0; i < num_queue_groups; i++) { + if (queue_properties[i].flags & + ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE) { + command_queue_desc.ordinal = i; + } + } + ze_command_queue_handle_t command_queue; + status = LEVEL_ZERO_RESULT_TO_STATUS( + syms, + zeCommandQueueCreate(level_zero_context, level_zero_device, + &command_queue_desc, &command_queue), + "zeCommandQueueCreate"); + + // Create a event pool. 
+ ze_event_pool_desc_t event_pool_desc = {}; + event_pool_desc.stype = ZE_STRUCTURE_TYPE_EVENT_POOL_DESC; + event_pool_desc.count = 1; + event_pool_desc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE; + ze_event_pool_handle_t event_pool; + status = LEVEL_ZERO_RESULT_TO_STATUS( + syms, + zeEventPoolCreate(level_zero_context, &event_pool_desc, 1, + &level_zero_device, &event_pool), + "zeEventPoolCreate"); + + // Create HAL-LevelZero device. + if (iree_status_is_ok(status)) { + status = iree_hal_level_zero_device_create_internal( + driver, identifier, level_zero_device, command_queue_desc.ordinal, + command_queue, event_pool, level_zero_context, syms, host_allocator, + out_device); + } + + // Create an event handle. + iree_hal_event_t* event; + iree_hal_level_zero_device_t* cast_device = + iree_hal_level_zero_device_cast(*out_device); + status = iree_hal_level_zero_event_create(&cast_device->context_wrapper, + cast_device->event_pool, &event); + cast_device->event = event; + *out_device = (iree_hal_device_t*)cast_device; + + if (!iree_status_is_ok(status)) { + syms->zeCommandQueueDestroy(command_queue); + syms->zeEventPoolDestroy(event_pool); + syms->zeContextDestroy(level_zero_context); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_string_view_t iree_hal_level_zero_device_id( + iree_hal_device_t* base_device) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + return device->identifier; +} + +static iree_allocator_t iree_hal_level_zero_device_host_allocator( + iree_hal_device_t* base_device) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + return device->context_wrapper.host_allocator; +} + +static iree_hal_allocator_t* iree_hal_level_zero_device_allocator( + iree_hal_device_t* base_device) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + return device->device_allocator; +} + +static void 
iree_hal_level_zero_replace_device_allocator( + iree_hal_device_t* base_device, iree_hal_allocator_t* new_allocator) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + iree_hal_allocator_retain(new_allocator); + iree_hal_allocator_release(device->device_allocator); + device->device_allocator = new_allocator; +} + +static iree_status_t iree_hal_level_zero_device_query_i64( + iree_hal_device_t* base_device, iree_string_view_t category, + iree_string_view_t key, int64_t* out_value) { + // iree_hal_level_zero_device_t* device = + // iree_hal_level_zero_device_cast(base_device); + *out_value = 0; + + if (iree_string_view_equal(category, + iree_make_cstring_view("hal.executable.format"))) { + *out_value = + iree_string_view_equal(key, iree_make_cstring_view("opencl-spirv-fb")) + ? 1 + : 0; + return iree_ok_status(); + } + + return iree_make_status( + IREE_STATUS_NOT_FOUND, + "unknown device configuration key value '%.*s :: %.*s'", + (int)category.size, category.data, (int)key.size, key.data); +} + +static iree_status_t iree_hal_level_zero_device_trim( + iree_hal_device_t* base_device) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + iree_arena_block_pool_trim(&device->block_pool); + return iree_hal_allocator_trim(device->device_allocator); +} + +static iree_status_t iree_hal_level_zero_device_create_command_buffer( + iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, + iree_hal_command_buffer_t** out_command_buffer) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + return iree_hal_level_zero_direct_command_buffer_create( + base_device, &device->context_wrapper, mode, command_categories, + queue_affinity, binding_capacity, &device->block_pool, device->device, + device->command_queue_ordinal, 
out_command_buffer); +} + +static iree_status_t iree_hal_level_zero_device_create_descriptor_set_layout( + iree_hal_device_t* base_device, + iree_hal_descriptor_set_layout_flags_t flags, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + return iree_hal_level_zero_descriptor_set_layout_create( + &device->context_wrapper, flags, binding_count, bindings, + out_descriptor_set_layout); +} + +static iree_status_t iree_hal_level_zero_device_create_event( + iree_hal_device_t* base_device, iree_hal_event_t** out_event) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + return iree_hal_level_zero_event_create(&device->context_wrapper, + device->event_pool, out_event); +} + +static iree_status_t iree_hal_level_zero_device_create_executable_cache( + iree_hal_device_t* base_device, iree_string_view_t identifier, + iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + return iree_hal_level_zero_nop_executable_cache_create( + &device->context_wrapper, identifier, device->device, + out_executable_cache); +} + +static iree_status_t iree_hal_level_zero_device_create_pipeline_layout( + iree_hal_device_t* base_device, iree_host_size_t push_constants, + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t* const* set_layouts, + iree_hal_pipeline_layout_t** out_pipeline_layout) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + return iree_hal_level_zero_pipeline_layout_create( + &device->context_wrapper, set_layout_count, set_layouts, push_constants, + out_pipeline_layout); +} + +static iree_status_t iree_hal_level_zero_device_create_semaphore( + iree_hal_device_t* base_device, 
uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + return iree_hal_level_zero_semaphore_create(&device->context_wrapper, + initial_value, out_semaphore); +} + +static iree_hal_semaphore_compatibility_t +iree_hal_level_zero_device_query_semaphore_compatibility( + iree_hal_device_t* base_device, iree_hal_semaphore_t* semaphore) { + // TODO: implement Level Zero semaphores. + return IREE_HAL_SEMAPHORE_COMPATIBILITY_HOST_ONLY; +} + +static iree_status_t iree_hal_level_zero_device_queue_alloca( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params, + iree_device_size_t allocation_size, + iree_hal_buffer_t** IREE_RESTRICT out_buffer) { + // TODO(benvanik): queue-ordered allocations. + IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list, + iree_infinite_timeout())); + IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer( + iree_hal_device_allocator(base_device), params, allocation_size, + iree_const_byte_span_empty(), out_buffer)); + IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_signal(signal_semaphore_list)); + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_device_queue_dealloca( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* buffer) { + // TODO(benvanik): queue-ordered allocations. 
+ IREE_RETURN_IF_ERROR(iree_hal_device_queue_barrier( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list)); + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_device_queue_execute( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_host_size_t command_buffer_count, + iree_hal_command_buffer_t* const* command_buffers) { + iree_hal_level_zero_device_t* device = + iree_hal_level_zero_device_cast(base_device); + // TODO(raikonenfnu): Once semaphore is implemented wait for semaphores + // TODO(thomasraoux): implement semaphores - for now this conservatively + // synchronizes after every submit. + for (int i = 0; i < command_buffer_count; i++) { + iree_hal_command_buffer_t* command_buffer = command_buffers[i]; + ze_command_list_handle_t command_list = + iree_hal_level_zero_direct_command_buffer_exec(command_buffer); + LEVEL_ZERO_RETURN_IF_ERROR( + device->context_wrapper.syms, + zeCommandListAppendSignalEvent( + command_list, iree_hal_level_zero_event_handle(device->event)), + "zeCommandListAppendSignalEvent"); + LEVEL_ZERO_RETURN_IF_ERROR( + device->context_wrapper.syms, + zeEventHostSynchronize(iree_hal_level_zero_event_handle(device->event), + IREE_DURATION_INFINITE), + "zeEventHostSynchronize"); + } + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_device_queue_flush( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) { + // Currently unused; we flush as submissions are made. 
+ return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_device_wait_semaphores( + iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "semaphore not implemented"); +} + +static const iree_hal_device_vtable_t iree_hal_level_zero_device_vtable = { + .destroy = iree_hal_level_zero_device_destroy, + .id = iree_hal_level_zero_device_id, + .host_allocator = iree_hal_level_zero_device_host_allocator, + .device_allocator = iree_hal_level_zero_device_allocator, + .replace_device_allocator = iree_hal_level_zero_replace_device_allocator, + .trim = iree_hal_level_zero_device_trim, + .query_i64 = iree_hal_level_zero_device_query_i64, + .create_command_buffer = iree_hal_level_zero_device_create_command_buffer, + .create_descriptor_set_layout = + iree_hal_level_zero_device_create_descriptor_set_layout, + .create_event = iree_hal_level_zero_device_create_event, + .create_executable_cache = + iree_hal_level_zero_device_create_executable_cache, + .create_pipeline_layout = iree_hal_level_zero_device_create_pipeline_layout, + .create_semaphore = iree_hal_level_zero_device_create_semaphore, + .query_semaphore_compatibility = + iree_hal_level_zero_device_query_semaphore_compatibility, + .transfer_range = iree_hal_device_submit_transfer_range_and_wait, + .queue_alloca = iree_hal_level_zero_device_queue_alloca, + .queue_dealloca = iree_hal_level_zero_device_queue_dealloca, + .queue_execute = iree_hal_level_zero_device_queue_execute, + .queue_flush = iree_hal_level_zero_device_queue_flush, + .wait_semaphores = iree_hal_level_zero_device_wait_semaphores, +}; diff --git a/experimental/level_zero/level_zero_device.h b/experimental/level_zero/level_zero_device.h new file mode 100644 index 0000000000000..33c36b7511a46 --- /dev/null +++ b/experimental/level_zero/level_zero_device.h @@ -0,0 +1,31 @@ +// Copyright 2021 The IREE Authors +// 
+// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_LEVEL_ZERO_DEVICE_H_ +#define IREE_HAL_LEVEL_ZERO_LEVEL_ZERO_DEVICE_H_ + +#include "experimental/level_zero/api.h" +#include "experimental/level_zero/dynamic_symbols.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a device that owns and manages its own hipContext. +iree_status_t iree_hal_level_zero_device_create( + iree_hal_driver_t* driver, iree_string_view_t identifier, + iree_hal_level_zero_dynamic_symbols_t* syms, + ze_device_handle_t level_zero_device, + ze_context_handle_t level_zero_context, iree_allocator_t host_allocator, + iree_hal_device_t** out_device); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LEVEL_ZERO_LEVEL_ZERO_DEVICE_H_ diff --git a/experimental/level_zero/level_zero_driver.c b/experimental/level_zero/level_zero_driver.c new file mode 100644 index 0000000000000..06913bc719356 --- /dev/null +++ b/experimental/level_zero/level_zero_driver.c @@ -0,0 +1,370 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "experimental/level_zero/api.h" +#include "experimental/level_zero/dynamic_symbols.h" +#include "experimental/level_zero/level_zero_device.h" +#include "experimental/level_zero/status_util.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" + +typedef struct iree_hal_level_zero_driver_t { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + // Identifier used for the driver in the IREE driver registry. 
+ // We allow overriding so that multiple LevelZero versions can be exposed in + // the same process. + iree_string_view_t identifier; + int default_device_index; + + // Level Zero Driver Handle. + ze_driver_handle_t driver_handle; + ze_context_handle_t context; + // LevelZero symbols. + iree_hal_level_zero_dynamic_symbols_t syms; +} iree_hal_level_zero_driver_t; + +// Pick a fixed lenght size for device names. +#define IREE_MAX_LEVEL_ZERO_DEVICE_NAME_LENGTH ZE_MAX_DEVICE_NAME + +static const iree_hal_driver_vtable_t iree_hal_level_zero_driver_vtable; + +static iree_hal_level_zero_driver_t* iree_hal_level_zero_driver_cast( + iree_hal_driver_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_level_zero_driver_vtable); + return (iree_hal_level_zero_driver_t*)base_value; +} + +IREE_API_EXPORT void iree_hal_level_zero_driver_options_initialize( + iree_hal_level_zero_driver_options_t* out_options) { + memset(out_options, 0, sizeof(*out_options)); + out_options->default_device_index = 0; +} + +static iree_status_t iree_hal_level_zero_driver_create_internal( + iree_string_view_t identifier, + const iree_hal_level_zero_driver_options_t* options, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { + iree_hal_level_zero_driver_t* driver = NULL; + iree_host_size_t total_size = sizeof(*driver) + identifier.size; + IREE_RETURN_IF_ERROR( + iree_allocator_malloc(host_allocator, total_size, (void**)&driver)); + iree_hal_resource_initialize(&iree_hal_level_zero_driver_vtable, + &driver->resource); + driver->host_allocator = host_allocator; + iree_string_view_append_to_buffer( + identifier, &driver->identifier, + (char*)driver + total_size - identifier.size); + driver->default_device_index = options->default_device_index; + iree_status_t status = iree_hal_level_zero_dynamic_symbols_initialize( + host_allocator, &driver->syms); + if (iree_status_is_ok(status)) { + // Initialize Level Zero + LEVEL_ZERO_RETURN_IF_ERROR(&driver->syms, zeInit(0), "zeInit"); + 
// Get the driver + uint32_t driver_count = 0; + LEVEL_ZERO_RETURN_IF_ERROR(&driver->syms, zeDriverGet(&driver_count, NULL), + "zeDriverGet"); + ze_driver_handle_t driver_handle; + LEVEL_ZERO_RETURN_IF_ERROR(&driver->syms, + zeDriverGet(&driver_count, &driver_handle), + "zeDriverGet"); + driver->driver_handle = driver_handle; + } + if (iree_status_is_ok(status)) { + *out_driver = (iree_hal_driver_t*)driver; + } else { + iree_hal_driver_release((iree_hal_driver_t*)driver); + } + return status; +} + +static void iree_hal_level_zero_driver_destroy(iree_hal_driver_t* base_driver) { + iree_hal_level_zero_driver_t* driver = + iree_hal_level_zero_driver_cast(base_driver); + iree_allocator_t host_allocator = driver->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_level_zero_dynamic_symbols_deinitialize(&driver->syms); + iree_allocator_free(host_allocator, driver); + + IREE_TRACE_ZONE_END(z0); +} + +IREE_API_EXPORT iree_status_t iree_hal_level_zero_driver_create( + iree_string_view_t identifier, + const iree_hal_level_zero_driver_options_t* options, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { + IREE_ASSERT_ARGUMENT(options); + IREE_ASSERT_ARGUMENT(out_driver); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_hal_level_zero_driver_create_internal( + identifier, options, host_allocator, out_driver); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +#define IREE_HAL_LEVEL_ZERO_DEVICE_UUID_TEXT_LENGTH 36 + +// Populates device information from the given Level Zero physical device +// handle. |out_device_info| must point to valid memory and additional data will +// be appended to |buffer_ptr| and the new pointer is returned. +// Puts the device UUID returned from Level Zero into |out_device_info->path| +// in a UUID canonical textual representation. 
+static uint8_t* iree_hal_level_zero_populate_device_info( + ze_device_handle_t device, iree_hal_level_zero_dynamic_symbols_t* syms, + uint8_t* buffer_ptr, iree_hal_device_info_t* out_device_info) { + ze_device_properties_t deviceProperties = { + .stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES}; + LEVEL_ZERO_IGNORE_ERROR(syms, + zeDeviceGetProperties(device, &deviceProperties)); + memset(out_device_info, 0, sizeof(*out_device_info)); + out_device_info->device_id = (iree_hal_device_id_t)device; + + iree_string_view_t device_name_string = iree_make_string_view( + deviceProperties.name, strlen(deviceProperties.name)); + buffer_ptr += iree_string_view_append_to_buffer( + device_name_string, &out_device_info->name, (char*)buffer_ptr); + + // Get device UUID. + const uint8_t* device_uuid = deviceProperties.uuid.id; + char device_path_str[IREE_HAL_LEVEL_ZERO_DEVICE_UUID_TEXT_LENGTH + 1] = {0}; + snprintf(device_path_str, sizeof(device_path_str), + "%02x%02x%02x%02x-" + "%02x%02x-" + "%02x%02x-" + "%02x%02x-" + "%02x%02x%02x%02x%02x%02x", + device_uuid[0], device_uuid[1], device_uuid[2], device_uuid[3], + device_uuid[4], device_uuid[5], device_uuid[6], device_uuid[7], + device_uuid[8], device_uuid[9], device_uuid[10], device_uuid[11], + device_uuid[12], device_uuid[13], device_uuid[14], device_uuid[15]); + iree_string_view_t device_path = iree_make_string_view( + device_path_str, IREE_ARRAYSIZE(device_path_str) - 1); + buffer_ptr += iree_string_view_append_to_buffer( + device_path, &out_device_info->path, (char*)buffer_ptr); + + return buffer_ptr; +} + +static iree_status_t iree_hal_level_zero_driver_query_available_devices( + iree_hal_driver_t* base_driver, iree_allocator_t host_allocator, + iree_host_size_t* out_device_info_count, + iree_hal_device_info_t** out_device_infos) { + iree_hal_level_zero_driver_t* driver = + iree_hal_level_zero_driver_cast(base_driver); + // Query the number of available Level Zero devices. 
+ uint32_t device_count = 0; + LEVEL_ZERO_RETURN_IF_ERROR( + &driver->syms, zeDeviceGet(driver->driver_handle, &device_count, NULL), + "zeDeviceGet"); + + // Allocate the return infos and populate with the devices. + iree_hal_device_info_t* device_infos = NULL; + iree_host_size_t total_size = + device_count * (sizeof(iree_hal_device_info_t) + + (IREE_HAL_LEVEL_ZERO_DEVICE_UUID_TEXT_LENGTH + 1 + + IREE_MAX_LEVEL_ZERO_DEVICE_NAME_LENGTH) * + sizeof(char)); + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&device_infos); + if (iree_status_is_ok(status)) { + uint8_t* buffer_ptr = + (uint8_t*)device_infos + device_count * sizeof(iree_hal_device_info_t); + ze_device_handle_t* device_list = + (ze_device_handle_t*)malloc(device_count * sizeof(ze_device_handle_t)); + status = LEVEL_ZERO_RESULT_TO_STATUS( + &driver->syms, + zeDeviceGet(driver->driver_handle, &device_count, device_list), + "zeDeviceGet"); + for (iree_host_size_t i = 0; i < device_count; ++i) { + if (!iree_status_is_ok(status)) break; + buffer_ptr = iree_hal_level_zero_populate_device_info( + device_list[i], &driver->syms, buffer_ptr, &device_infos[i]); + } + free(device_list); + } + if (iree_status_is_ok(status)) { + *out_device_info_count = device_count; + *out_device_infos = device_infos; + } else { + iree_allocator_free(host_allocator, device_infos); + } + return status; +} + +static iree_status_t iree_hal_level_zero_driver_dump_device_info( + iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, + iree_string_builder_t* builder) { + iree_hal_level_zero_driver_t* driver = + iree_hal_level_zero_driver_cast(base_driver); + ze_device_handle_t device = (ze_device_handle_t)device_id; + if (!device) return iree_ok_status(); + // TODO: dump detailed device info. 
+ (void)driver; + (void)device; + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_driver_select_default_device( + iree_hal_level_zero_dynamic_symbols_t* syms, int default_device_index, + iree_allocator_t host_allocator, ze_driver_handle_t driver_handle, + ze_device_handle_t* out_device) { + uint32_t device_count = 0; + LEVEL_ZERO_RETURN_IF_ERROR( + syms, zeDeviceGet(driver_handle, &device_count, NULL), "zeDeviceGet"); + iree_status_t status = iree_ok_status(); + if (device_count == 0 || default_device_index >= device_count) { + status = iree_make_status(IREE_STATUS_NOT_FOUND, + "default device %d not found (of %d enumerated)", + default_device_index, device_count); + } else { + ze_device_handle_t* device_list = + (ze_device_handle_t*)malloc(device_count * sizeof(ze_device_handle_t)); + status = LEVEL_ZERO_RESULT_TO_STATUS( + syms, zeDeviceGet(driver_handle, &device_count, device_list), + "zeDeviceGet"); + *out_device = device_list[default_device_index]; + } + return status; +} + +static iree_status_t iree_hal_level_zero_driver_create_device_by_id( + iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, + iree_host_size_t param_count, const iree_string_pair_t* params, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + iree_hal_level_zero_driver_t* driver = + iree_hal_level_zero_driver_cast(base_driver); + IREE_TRACE_ZONE_BEGIN(z0); + + // Use either the specified device (enumerated earlier) or whatever default + // one was specified when the driver was created. 
+ ze_device_handle_t device = (ze_device_handle_t)device_id; + if (device == 0) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_level_zero_driver_select_default_device( + &driver->syms, driver->default_device_index, host_allocator, + driver->driver_handle, &device)); + } + ze_context_desc_t context_description = {}; + context_description.stype = ZE_STRUCTURE_TYPE_CONTEXT_DESC; + ze_context_handle_t context; + LEVEL_ZERO_RETURN_IF_ERROR( + &driver->syms, + zeContextCreate(driver->driver_handle, &context_description, &context), + "zeContextCreate"); + iree_string_view_t device_name = iree_make_cstring_view("level_zero"); + + // Attempt to create the device. + iree_status_t status = iree_hal_level_zero_device_create( + base_driver, device_name, &driver->syms, device, context, host_allocator, + out_device); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static bool uuids_equal(const uint8_t* id1, const uint8_t* id2) { + return memcmp(id1, id2, + 16 < ZE_MAX_DEVICE_UUID_SIZE ? 16 : ZE_MAX_DEVICE_UUID_SIZE) == + 0; +} + +static iree_status_t iree_hal_level_zero_driver_create_device_by_uuid( + iree_hal_driver_t* base_driver, iree_string_view_t driver_name, + const uint8_t* device_uuid, iree_host_size_t param_count, + const iree_string_pair_t* params, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + iree_hal_level_zero_driver_t* driver = + iree_hal_level_zero_driver_cast(base_driver); + ze_device_handle_t* ze_devices = NULL; + iree_status_t status; + + // Get the number of Level Zero devices. + uint32_t device_count = 0; + IREE_LEVEL_ZERO_TRY(LEVEL_ZERO_RESULT_TO_STATUS( + &driver->syms, zeDeviceGet(driver->driver_handle, &device_count, NULL), + "zeDeviceGet")); + + // Get all Level Zero devices. 
+ iree_allocator_malloc(driver->host_allocator, + sizeof(ze_device_handle_t) * device_count, + (void**)&ze_devices); + IREE_LEVEL_ZERO_TRY(LEVEL_ZERO_RESULT_TO_STATUS( + &driver->syms, + zeDeviceGet(driver->driver_handle, &device_count, ze_devices), + "zeDeviceGet")); + + // Find the Level Zero device with the given UUID. + bool is_device_found = false; + for (uint32_t i = 0; i < device_count; ++i) { + ze_device_properties_t device_properties = { + .stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES}; + IREE_LEVEL_ZERO_TRY(LEVEL_ZERO_RESULT_TO_STATUS( + &driver->syms, zeDeviceGetProperties(ze_devices[i], &device_properties), + "zeDeviceGetProperties")); + if (uuids_equal(device_uuid, device_properties.uuid.id)) { + IREE_LEVEL_ZERO_TRY(iree_hal_level_zero_driver_create_device_by_id( + base_driver, (uintptr_t)ze_devices[i], param_count, params, + host_allocator, out_device)); + is_device_found = true; + break; + } + } + + if (!is_device_found) { + status = iree_make_status(IREE_STATUS_NOT_FOUND, + "Could not find Level Zero device by UUID."); + } + +cleanup: + iree_allocator_free(driver->host_allocator, ze_devices); + + return status; +} + +static iree_status_t iree_hal_level_zero_driver_create_device_by_path( + iree_hal_driver_t* base_driver, iree_string_view_t driver_name, + iree_string_view_t device_path, iree_host_size_t param_count, + const iree_string_pair_t* params, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + if (iree_string_view_is_empty(device_path)) { + return iree_hal_level_zero_driver_create_device_by_id( + base_driver, IREE_HAL_DEVICE_ID_DEFAULT, param_count, params, + host_allocator, out_device); + } + + // Try parsing as a device UUID. 
+ uint8_t device_uuid[16] = {0}; + if (iree_string_view_parse_hex_bytes(device_path, 16, device_uuid)) { + return iree_hal_level_zero_driver_create_device_by_uuid( + base_driver, driver_name, device_uuid, param_count, params, + host_allocator, out_device); + } + + uint64_t device_id = 0; + if (iree_string_view_atoi_uint64(device_path, &device_id)) { + return iree_hal_level_zero_driver_create_device_by_id( + base_driver, (uintptr_t)device_id, param_count, params, host_allocator, + out_device); + } + + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unsupported device path"); +} + +static const iree_hal_driver_vtable_t iree_hal_level_zero_driver_vtable = { + .destroy = iree_hal_level_zero_driver_destroy, + .query_available_devices = + iree_hal_level_zero_driver_query_available_devices, + .dump_device_info = iree_hal_level_zero_driver_dump_device_info, + .create_device_by_id = iree_hal_level_zero_driver_create_device_by_id, + .create_device_by_path = iree_hal_level_zero_driver_create_device_by_path, +}; diff --git a/experimental/level_zero/level_zero_event.c b/experimental/level_zero/level_zero_event.c new file mode 100644 index 0000000000000..0ada8616ed064 --- /dev/null +++ b/experimental/level_zero/level_zero_event.c @@ -0,0 +1,78 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/level_zero/level_zero_event.h" + +#include + +#include "iree/base/api.h" + +// Dummy events for now, don't do anything. 
+typedef struct iree_hal_level_zero_event_t { + iree_hal_resource_t resource; + iree_hal_level_zero_context_wrapper_t* context_wrapper; + ze_event_handle_t handle; +} iree_hal_level_zero_event_t; + +static const iree_hal_event_vtable_t iree_hal_level_zero_event_vtable; + +static iree_hal_level_zero_event_t* iree_hal_level_zero_event_cast( + iree_hal_event_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_level_zero_event_vtable); + return (iree_hal_level_zero_event_t*)base_value; +} + +iree_status_t iree_hal_level_zero_event_create( + iree_hal_level_zero_context_wrapper_t* context_wrapper, + ze_event_pool_handle_t event_pool, iree_hal_event_t** out_event) { + IREE_ASSERT_ARGUMENT(context_wrapper); + IREE_ASSERT_ARGUMENT(out_event); + *out_event = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_level_zero_event_t* event = NULL; + iree_status_t status = iree_allocator_malloc(context_wrapper->host_allocator, + sizeof(*event), (void**)&event); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_level_zero_event_vtable, + &event->resource); + ze_event_desc_t event_desc = {}; + event_desc.stype = ZE_STRUCTURE_TYPE_EVENT_DESC; + event_desc.signal = ZE_EVENT_SCOPE_FLAG_HOST; + event_desc.wait = ZE_EVENT_SCOPE_FLAG_HOST; + ze_event_handle_t handle; + LEVEL_ZERO_RETURN_IF_ERROR(context_wrapper->syms, + zeEventCreate(event_pool, &event_desc, &handle), + "zeEventCreate"); + event->handle = handle; + event->context_wrapper = context_wrapper; + *out_event = (iree_hal_event_t*)event; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_level_zero_event_destroy(iree_hal_event_t* base_event) { + iree_hal_level_zero_event_t* event = + iree_hal_level_zero_event_cast(base_event); + iree_allocator_t host_allocator = event->context_wrapper->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + LEVEL_ZERO_IGNORE_ERROR(event->context_wrapper->syms, + zeEventDestroy(event->handle)); + iree_allocator_free(host_allocator, event); + + 
IREE_TRACE_ZONE_END(z0); +} + +ze_event_handle_t iree_hal_level_zero_event_handle( + const iree_hal_event_t* base_event) { + return ((const iree_hal_level_zero_event_t*)base_event)->handle; +} + +static const iree_hal_event_vtable_t iree_hal_level_zero_event_vtable = { + .destroy = iree_hal_level_zero_event_destroy, +}; diff --git a/experimental/level_zero/level_zero_event.h b/experimental/level_zero/level_zero_event.h new file mode 100644 index 0000000000000..d5924dec82ef7 --- /dev/null +++ b/experimental/level_zero/level_zero_event.h @@ -0,0 +1,36 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_EVENT_H_ +#define IREE_HAL_LEVEL_ZERO_EVENT_H_ + +#include "experimental/level_zero/context_wrapper.h" +#include "experimental/level_zero/level_zero_headers.h" +#include "experimental/level_zero/status_util.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a dummy event object. Object will be represented by level_zero +// command list nodes so nothing is created at creation time. When an event is +// signaled in the command buffer we will add the appropriate edges to enforce +// the right synchronization. +iree_status_t iree_hal_level_zero_event_create( + iree_hal_level_zero_context_wrapper_t* context_wrapper, + ze_event_pool_handle_t event_pool, iree_hal_event_t** out_event); + +// Returns Level Zero event handle. 
+ze_event_handle_t iree_hal_level_zero_event_handle( + const iree_hal_event_t* event); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LEVEL_ZERO_EVENT_H_ diff --git a/experimental/level_zero/level_zero_headers.h b/experimental/level_zero/level_zero_headers.h new file mode 100644 index 0000000000000..aef66c5b44ff7 --- /dev/null +++ b/experimental/level_zero/level_zero_headers.h @@ -0,0 +1,17 @@ + +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_LEVEL_ZERO_HEADERS_H_ +#define IREE_HAL_LEVEL_ZERO_LEVEL_ZERO_HEADERS_H_ + +#if defined(IREE_PTR_SIZE_32) +#error 32-bit not supported on level zero +#endif // defined(IREE_PTR_SIZE_32) + +#include "ze_api.h" // IWYU pragma: export + +#endif // IREE_HAL_LEVEL_ZERO_LEVEL_ZERO_HEADERS_H_ diff --git a/experimental/level_zero/native_executable.c b/experimental/level_zero/native_executable.c new file mode 100644 index 0000000000000..4fdfac57be595 --- /dev/null +++ b/experimental/level_zero/native_executable.c @@ -0,0 +1,186 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/level_zero/native_executable.h" + +#include + +#include "experimental/level_zero/dynamic_symbols.h" +#include "experimental/level_zero/pipeline_layout.h" +#include "experimental/level_zero/status_util.h" +#include "iree/base/api.h" + +// flatcc schemas: +#include "iree/base/internal/flatcc/parsing.h" +#include "iree/schemas/level_zero_executable_def_reader.h" +#include "iree/schemas/level_zero_executable_def_verifier.h" + +typedef struct iree_hal_level_zero_native_executable_function_t { + ze_kernel_handle_t level_zero_function; + uint32_t block_size_x; + uint32_t block_size_y; + uint32_t block_size_z; +} iree_hal_level_zero_native_executable_function_t; + +typedef struct iree_hal_level_zero_native_executable_t { + iree_hal_resource_t resource; + iree_hal_level_zero_context_wrapper_t* context; + iree_hal_pipeline_layout_t** pipeline_layouts; + iree_host_size_t entry_count; + ze_module_handle_t module; + iree_hal_level_zero_native_executable_function_t entry_functions[]; +} iree_hal_level_zero_native_executable_t; + +static const iree_hal_executable_vtable_t + iree_hal_level_zero_native_executable_vtable; + +static iree_hal_level_zero_native_executable_t* +iree_hal_level_zero_native_executable_cast(iree_hal_executable_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_level_zero_native_executable_vtable); + return (iree_hal_level_zero_native_executable_t*)base_value; +} + +iree_status_t iree_hal_level_zero_native_executable_create( + iree_hal_level_zero_context_wrapper_t* context, + const iree_hal_executable_params_t* executable_params, + ze_device_handle_t level_zero_device, + iree_hal_executable_t** out_executable) { + IREE_ASSERT_ARGUMENT(context); + IREE_ASSERT_ARGUMENT(executable_params); + IREE_ASSERT_ARGUMENT(out_executable); + *out_executable = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_level_zero_native_executable_t* executable = NULL; + + 
iree_LEVEL_ZEROExecutableDef_table_t executable_def = + iree_LEVEL_ZEROExecutableDef_as_root( + executable_params->executable_data.data); + + // Create the kernel module. + flatbuffers_uint32_vec_t level_zero_image = + iree_LEVEL_ZEROExecutableDef_level_zero_image_get(executable_def); + flatbuffers_string_vec_t entry_points_vec = + iree_LEVEL_ZEROExecutableDef_entry_points_get(executable_def); + iree_LEVEL_ZEROBlockSizeDef_vec_t block_sizes_vec = + iree_LEVEL_ZEROExecutableDef_block_sizes_get(executable_def); + iree_host_size_t entry_count = flatbuffers_string_vec_len(entry_points_vec); + iree_host_size_t total_size = + sizeof(*executable) + + entry_count * sizeof(iree_hal_level_zero_native_executable_function_t) + + entry_count * sizeof(iree_hal_pipeline_layout_t*); + iree_status_t status = iree_allocator_malloc(context->host_allocator, + total_size, (void**)&executable); + executable->pipeline_layouts = + (void*)((char*)executable + sizeof(*executable) + + entry_count * + sizeof(iree_hal_level_zero_native_executable_function_t)); + ze_module_handle_t module = NULL; + ze_module_build_log_handle_t build_log; + if (iree_status_is_ok(status)) { + ze_module_desc_t module_desc = {}; + module_desc.stype = ZE_STRUCTURE_TYPE_MODULE_DESC; + module_desc.format = ZE_MODULE_FORMAT_IL_SPIRV; + iree_const_byte_span_t code = iree_make_const_byte_span( + level_zero_image, + flatbuffers_uint32_vec_len(level_zero_image) * sizeof(uint32_t)); + module_desc.pInputModule = (const uint8_t*)(code.data); + module_desc.inputSize = code.data_length; + module_desc.pBuildFlags = ""; + status = LEVEL_ZERO_RESULT_TO_STATUS( + context->syms, + zeModuleCreate(context->level_zero_context, level_zero_device, + &module_desc, &module, &build_log), + "zeModuleCreate"); + } + for (iree_host_size_t i = 0; i < entry_count; i++) { + if (iree_status_is_ok(status)) { + const char* entry_name = flatbuffers_string_vec_at(entry_points_vec, i); + ze_kernel_handle_t function = NULL; + ze_kernel_desc_t 
kernel_desc = {.stype = ZE_STRUCTURE_TYPE_KERNEL_DESC}; + // kernel_desc.pKernelName = "simple_mul_dispatch_0"; + kernel_desc.pKernelName = entry_name; + LEVEL_ZERO_RETURN_IF_ERROR( + context->syms, zeKernelCreate(module, &kernel_desc, &function), + "zeKernelCreate"); + executable->entry_functions[i].level_zero_function = function; + executable->entry_functions[i].block_size_x = block_sizes_vec[i].x; + executable->entry_functions[i].block_size_y = block_sizes_vec[i].y; + executable->entry_functions[i].block_size_z = block_sizes_vec[i].z; + executable->pipeline_layouts[i] = executable_params->pipeline_layouts[i]; + iree_hal_pipeline_layout_retain(executable_params->pipeline_layouts[i]); + } + } + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_level_zero_native_executable_vtable, + &executable->resource); + executable->module = module; + executable->context = context; + *out_executable = (iree_hal_executable_t*)executable; + } else { + // print log + size_t szLog = 0; + status = LEVEL_ZERO_RESULT_TO_STATUS( + context->syms, zeModuleBuildLogGetString(build_log, &szLog, NULL), + "zeModuleBuildLogGetString"); + char* stringLog = (char*)malloc(szLog); + status = LEVEL_ZERO_RESULT_TO_STATUS( + context->syms, zeModuleBuildLogGetString(build_log, &szLog, stringLog), + "zeModuleBuildLogGetString"); + status = LEVEL_ZERO_RESULT_TO_STATUS(context->syms, + zeModuleBuildLogDestroy(build_log), + "zeModuleBuildLogDestroy"); + iree_hal_executable_destroy((iree_hal_executable_t*)executable); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +ze_kernel_handle_t iree_hal_level_zero_native_executable_for_entry_point( + iree_hal_executable_t* base_executable, int32_t entry_point) { + iree_hal_level_zero_native_executable_t* executable = + iree_hal_level_zero_native_executable_cast(base_executable); + return executable->entry_functions[entry_point].level_zero_function; +} + +iree_status_t iree_hal_level_zero_native_executable_block_size( + 
iree_hal_executable_t* base_executable, int32_t entry_point, uint32_t* x, + uint32_t* y, uint32_t* z) { + iree_hal_level_zero_native_executable_t* executable = + iree_hal_level_zero_native_executable_cast(base_executable); + *x = executable->entry_functions[entry_point].block_size_x; + *y = executable->entry_functions[entry_point].block_size_y; + *z = executable->entry_functions[entry_point].block_size_z; + return iree_ok_status(); +} + +iree_hal_pipeline_layout_t* iree_hal_level_zero_executable_get_layout( + iree_hal_executable_t* base_executable, int32_t entry_point) { + iree_hal_level_zero_native_executable_t* executable = + iree_hal_level_zero_native_executable_cast(base_executable); + return executable->pipeline_layouts[entry_point]; +} + +static void iree_hal_level_zero_native_executable_destroy( + iree_hal_executable_t* base_executable) { + iree_hal_level_zero_native_executable_t* executable = + iree_hal_level_zero_native_executable_cast(base_executable); + iree_allocator_t host_allocator = executable->context->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < executable->entry_count; ++i) { + iree_hal_pipeline_layout_release(executable->pipeline_layouts[i]); + } + iree_allocator_free(host_allocator, executable); + + IREE_TRACE_ZONE_END(z0); +} + +static const iree_hal_executable_vtable_t + iree_hal_level_zero_native_executable_vtable = { + .destroy = iree_hal_level_zero_native_executable_destroy, +}; diff --git a/experimental/level_zero/native_executable.h b/experimental/level_zero/native_executable.h new file mode 100644 index 0000000000000..b02de6bc81bad --- /dev/null +++ b/experimental/level_zero/native_executable.h @@ -0,0 +1,45 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_NATIVE_EXECUTABLE_H_ +#define IREE_HAL_LEVEL_ZERO_NATIVE_EXECUTABLE_H_ + +#include + +#include "experimental/level_zero/context_wrapper.h" +#include "experimental/level_zero/level_zero_headers.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates an executable from a SPV module. The module may contain several +// kernels that can be extracted along with the associated block size. +iree_status_t iree_hal_level_zero_native_executable_create( + iree_hal_level_zero_context_wrapper_t* context, + const iree_hal_executable_params_t* executable_params, + ze_device_handle_t level_zero_device, + iree_hal_executable_t** out_executable); + +ze_kernel_handle_t iree_hal_level_zero_native_executable_for_entry_point( + iree_hal_executable_t* executable, int32_t entry_point); + +// Return the block size of the given |entry_point| within the executable. +iree_status_t iree_hal_level_zero_native_executable_block_size( + iree_hal_executable_t* executable, int32_t entry_point, uint32_t* x, + uint32_t* y, uint32_t* z); + +/// Return the layout associated with the entry point. +iree_hal_pipeline_layout_t* iree_hal_level_zero_executable_get_layout( + iree_hal_executable_t* executable, int32_t entry_point); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LEVEL_ZERO_NATIVE_EXECUTABLE_H_ diff --git a/experimental/level_zero/nop_executable_cache.c b/experimental/level_zero/nop_executable_cache.c new file mode 100644 index 0000000000000..376ee17ab6b47 --- /dev/null +++ b/experimental/level_zero/nop_executable_cache.c @@ -0,0 +1,96 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/level_zero/nop_executable_cache.h" + +#include +#include + +#include "experimental/level_zero/native_executable.h" +#include "iree/base/api.h" + +typedef struct iree_hal_level_zero_nop_executable_cache_t { + iree_hal_resource_t resource; + ze_device_handle_t level_zero_device; + iree_hal_level_zero_context_wrapper_t* context; +} iree_hal_level_zero_nop_executable_cache_t; + +static const iree_hal_executable_cache_vtable_t + iree_hal_level_zero_nop_executable_cache_vtable; + +static iree_hal_level_zero_nop_executable_cache_t* +iree_hal_level_zero_nop_executable_cache_cast( + iree_hal_executable_cache_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_level_zero_nop_executable_cache_vtable); + return (iree_hal_level_zero_nop_executable_cache_t*)base_value; +} + +iree_status_t iree_hal_level_zero_nop_executable_cache_create( + iree_hal_level_zero_context_wrapper_t* context, + iree_string_view_t identifier, ze_device_handle_t level_zero_device, + iree_hal_executable_cache_t** out_executable_cache) { + IREE_ASSERT_ARGUMENT(out_executable_cache); + *out_executable_cache = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_level_zero_nop_executable_cache_t* executable_cache = NULL; + iree_status_t status = + iree_allocator_malloc(context->host_allocator, sizeof(*executable_cache), + (void**)&executable_cache); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize( + &iree_hal_level_zero_nop_executable_cache_vtable, + &executable_cache->resource); + executable_cache->context = context; + executable_cache->level_zero_device = level_zero_device; + + *out_executable_cache = (iree_hal_executable_cache_t*)executable_cache; + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_level_zero_nop_executable_cache_destroy( + iree_hal_executable_cache_t* base_executable_cache) { + iree_hal_level_zero_nop_executable_cache_t* executable_cache = + 
iree_hal_level_zero_nop_executable_cache_cast(base_executable_cache); + iree_allocator_t host_allocator = executable_cache->context->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, executable_cache); + + IREE_TRACE_ZONE_END(z0); +} + +static bool iree_hal_level_zero_nop_executable_cache_can_prepare_format( + iree_hal_executable_cache_t* base_executable_cache, + iree_hal_executable_caching_mode_t caching_mode, + iree_string_view_t executable_format) { + return iree_string_view_equal(executable_format, + iree_make_cstring_view("ZERO")); +} + +static iree_status_t +iree_hal_level_zero_nop_executable_cache_prepare_executable( + iree_hal_executable_cache_t* base_executable_cache, + const iree_hal_executable_params_t* executable_params, + iree_hal_executable_t** out_executable) { + iree_hal_level_zero_nop_executable_cache_t* executable_cache = + iree_hal_level_zero_nop_executable_cache_cast(base_executable_cache); + return iree_hal_level_zero_native_executable_create( + executable_cache->context, executable_params, + executable_cache->level_zero_device, out_executable); +} + +static const iree_hal_executable_cache_vtable_t + iree_hal_level_zero_nop_executable_cache_vtable = { + .destroy = iree_hal_level_zero_nop_executable_cache_destroy, + .can_prepare_format = + iree_hal_level_zero_nop_executable_cache_can_prepare_format, + .prepare_executable = + iree_hal_level_zero_nop_executable_cache_prepare_executable, +}; diff --git a/experimental/level_zero/nop_executable_cache.h b/experimental/level_zero/nop_executable_cache.h new file mode 100644 index 0000000000000..05e48cc07b2fb --- /dev/null +++ b/experimental/level_zero/nop_executable_cache.h @@ -0,0 +1,30 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_NOP_EXECUTABLE_CACHE_H_ +#define IREE_HAL_LEVEL_ZERO_NOP_EXECUTABLE_CACHE_H_ + +#include "experimental/level_zero/context_wrapper.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a no-op executable cache that does not cache at all. +// This is useful to isolate pipeline caching behavior and verify compilation +// behavior. +iree_status_t iree_hal_level_zero_nop_executable_cache_create( + iree_hal_level_zero_context_wrapper_t* context, + iree_string_view_t identifier, ze_device_handle_t level_zero_device, + iree_hal_executable_cache_t** out_executable_cache); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LEVEL_ZERO_NOP_EXECUTABLE_CACHE_H_ diff --git a/experimental/level_zero/pipeline_layout.c b/experimental/level_zero/pipeline_layout.c new file mode 100644 index 0000000000000..f21af4bbe1f5b --- /dev/null +++ b/experimental/level_zero/pipeline_layout.c @@ -0,0 +1,211 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/level_zero/pipeline_layout.h" + +#include + +#include "iree/base/api.h" + +//===----------------------------------------------------------------------===// +// iree_hal_level_zero_descriptor_set_layout_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_level_zero_descriptor_set_layout_t { + iree_hal_resource_t resource; + iree_hal_level_zero_context_wrapper_t* context; + iree_host_size_t binding_count; +} iree_hal_level_zero_descriptor_set_layout_t; + +static const iree_hal_descriptor_set_layout_vtable_t + iree_hal_level_zero_descriptor_set_layout_vtable; + +static iree_hal_level_zero_descriptor_set_layout_t* +iree_hal_level_zero_descriptor_set_layout_cast( + iree_hal_descriptor_set_layout_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_level_zero_descriptor_set_layout_vtable); + return (iree_hal_level_zero_descriptor_set_layout_t*)base_value; +} + +iree_status_t iree_hal_level_zero_descriptor_set_layout_create( + iree_hal_level_zero_context_wrapper_t* context, + iree_hal_descriptor_set_layout_flags_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + IREE_ASSERT_ARGUMENT(context); + IREE_ASSERT_ARGUMENT(!binding_count || bindings); + IREE_ASSERT_ARGUMENT(out_descriptor_set_layout); + *out_descriptor_set_layout = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_level_zero_descriptor_set_layout_t* descriptor_set_layout = NULL; + iree_status_t status = iree_allocator_malloc(context->host_allocator, + sizeof(*descriptor_set_layout), + (void**)&descriptor_set_layout); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize( + &iree_hal_level_zero_descriptor_set_layout_vtable, + &descriptor_set_layout->resource); + descriptor_set_layout->context = context; + 
descriptor_set_layout->binding_count = binding_count; + *out_descriptor_set_layout = + (iree_hal_descriptor_set_layout_t*)descriptor_set_layout; + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +iree_host_size_t iree_hal_level_zero_descriptor_set_layout_binding_count( + iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { + iree_hal_level_zero_descriptor_set_layout_t* descriptor_set_layout = + iree_hal_level_zero_descriptor_set_layout_cast( + base_descriptor_set_layout); + return descriptor_set_layout->binding_count; +} + +static void iree_hal_level_zero_descriptor_set_layout_destroy( + iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { + iree_hal_level_zero_descriptor_set_layout_t* descriptor_set_layout = + iree_hal_level_zero_descriptor_set_layout_cast( + base_descriptor_set_layout); + iree_allocator_t host_allocator = + descriptor_set_layout->context->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, descriptor_set_layout); + + IREE_TRACE_ZONE_END(z0); +} + +static const iree_hal_descriptor_set_layout_vtable_t + iree_hal_level_zero_descriptor_set_layout_vtable = { + .destroy = iree_hal_level_zero_descriptor_set_layout_destroy, +}; + +//===----------------------------------------------------------------------===// +// iree_hal_level_zero_pipeline_layout_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_level_zero_pipeline_layout_t { + iree_hal_resource_t resource; + iree_hal_level_zero_context_wrapper_t* context; + iree_host_size_t push_constant_base_index; + iree_host_size_t push_constant_count; + iree_host_size_t set_layout_count; + iree_hal_descriptor_set_layout_t* set_layouts[]; +} iree_hal_level_zero_pipeline_layout_t; + +static const iree_hal_pipeline_layout_vtable_t + iree_hal_level_zero_pipeline_layout_vtable; + +static iree_hal_level_zero_pipeline_layout_t* +iree_hal_level_zero_pipeline_layout_cast( + iree_hal_pipeline_layout_t* 
base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_level_zero_pipeline_layout_vtable); + return (iree_hal_level_zero_pipeline_layout_t*)base_value; +} + +static void iree_hal_level_zero_pipeline_layout_destroy( + iree_hal_pipeline_layout_t* base_pipeline_layout) { + iree_hal_level_zero_pipeline_layout_t* pipeline_layout = + iree_hal_level_zero_pipeline_layout_cast(base_pipeline_layout); + iree_allocator_t host_allocator = pipeline_layout->context->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < pipeline_layout->set_layout_count; ++i) { + iree_hal_descriptor_set_layout_release(pipeline_layout->set_layouts[i]); + } + iree_allocator_free(host_allocator, pipeline_layout); + + IREE_TRACE_ZONE_END(z0); +} + +iree_status_t iree_hal_level_zero_pipeline_layout_create( + iree_hal_level_zero_context_wrapper_t* context, + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t* const* set_layouts, + iree_host_size_t push_constant_count, + iree_hal_pipeline_layout_t** out_pipeline_layout) { + IREE_ASSERT_ARGUMENT(context); + IREE_ASSERT_ARGUMENT(!set_layout_count || set_layouts); + IREE_ASSERT_ARGUMENT(out_pipeline_layout); + *out_pipeline_layout = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + if (push_constant_count > IREE_HAL_LEVEL_ZERO_MAX_PUSH_CONSTANT_COUNT) { + IREE_TRACE_ZONE_END(z0);  // Close the zone opened above before bailing. + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "push constant count %zu over the limit of %d", + push_constant_count, + IREE_HAL_LEVEL_ZERO_MAX_PUSH_CONSTANT_COUNT); + } + + // Currently the executable layout doesn't do anything. + // TODO: Handle creating the argument layout at that time handling both push + // constant and buffers.
+ iree_hal_level_zero_pipeline_layout_t* pipeline_layout = NULL; + iree_host_size_t total_size = + sizeof(*pipeline_layout) + + set_layout_count * sizeof(*pipeline_layout->set_layouts); + iree_status_t status = iree_allocator_malloc( + context->host_allocator, total_size, (void**)&pipeline_layout); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_level_zero_pipeline_layout_vtable, + &pipeline_layout->resource); + pipeline_layout->context = context; + pipeline_layout->set_layout_count = set_layout_count; + iree_host_size_t binding_number = 0; + for (iree_host_size_t i = 0; i < set_layout_count; ++i) { + pipeline_layout->set_layouts[i] = set_layouts[i]; + iree_hal_descriptor_set_layout_retain(set_layouts[i]); + binding_number += iree_hal_level_zero_descriptor_set_layout_binding_count( + set_layouts[i]); + } + pipeline_layout->push_constant_base_index = binding_number; + pipeline_layout->push_constant_count = push_constant_count; + *out_pipeline_layout = (iree_hal_pipeline_layout_t*)pipeline_layout; + IREE_HAL_ASSERT_TYPE(*out_pipeline_layout, + &iree_hal_level_zero_pipeline_layout_vtable); + IREE_HAL_ASSERT_TYPE(pipeline_layout, + &iree_hal_level_zero_pipeline_layout_vtable); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +iree_host_size_t iree_hal_level_zero_base_binding_index( + iree_hal_pipeline_layout_t* base_pipeline_layout, uint32_t set) { + iree_hal_level_zero_pipeline_layout_t* pipeline_layout = + iree_hal_level_zero_pipeline_layout_cast(base_pipeline_layout); + iree_host_size_t base_binding = 0; + for (iree_host_size_t i = 0; i < set; ++i) { + iree_host_size_t binding_count = + iree_hal_level_zero_descriptor_set_layout_binding_count( + pipeline_layout->set_layouts[i]); + base_binding += binding_count; + } + return base_binding; +} + +iree_host_size_t iree_hal_level_zero_push_constant_index( + iree_hal_pipeline_layout_t* base_pipeline_layout) { + iree_hal_level_zero_pipeline_layout_t* pipeline_layout = + 
iree_hal_level_zero_pipeline_layout_cast(base_pipeline_layout); + return pipeline_layout->push_constant_base_index; +} + +iree_host_size_t iree_hal_level_zero_pipeline_layout_num_constants( + iree_hal_pipeline_layout_t* base_pipeline_layout) { + iree_hal_level_zero_pipeline_layout_t* pipeline_layout = + iree_hal_level_zero_pipeline_layout_cast(base_pipeline_layout); + return pipeline_layout->push_constant_count; +} + +static const iree_hal_pipeline_layout_vtable_t + iree_hal_level_zero_pipeline_layout_vtable = { + .destroy = iree_hal_level_zero_pipeline_layout_destroy, +}; diff --git a/experimental/level_zero/pipeline_layout.h b/experimental/level_zero/pipeline_layout.h new file mode 100644 index 0000000000000..25b4abfb6d979 --- /dev/null +++ b/experimental/level_zero/pipeline_layout.h @@ -0,0 +1,63 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_PIPELINE_LAYOUT_H_ +#define IREE_HAL_LEVEL_ZERO_PIPELINE_LAYOUT_H_ + +#include "experimental/level_zero/context_wrapper.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#define IREE_HAL_LEVEL_ZERO_MAX_PUSH_CONSTANT_COUNT 64 + +//===----------------------------------------------------------------------===// +// iree_hal_level_zero_descriptor_set_layout_t +//===----------------------------------------------------------------------===// + +iree_status_t iree_hal_level_zero_descriptor_set_layout_create( + iree_hal_level_zero_context_wrapper_t* context, + iree_hal_descriptor_set_layout_flags_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout); + +// Return the binding count for the given descriptor set layout. 
+iree_host_size_t iree_hal_level_zero_descriptor_set_layout_binding_count( + iree_hal_descriptor_set_layout_t* descriptor_set_layout); + +//===----------------------------------------------------------------------===// +// iree_hal_level_zero_pipeline_layout_t +//===----------------------------------------------------------------------===// + +// Creates the kernel arguments. +iree_status_t iree_hal_level_zero_pipeline_layout_create( + iree_hal_level_zero_context_wrapper_t* context, + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t* const* set_layouts, + iree_host_size_t push_constant_count, + iree_hal_pipeline_layout_t** out_pipeline_layout); + +// Return the base binding index for the given set. +iree_host_size_t iree_hal_level_zero_base_binding_index( + iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set); + +// Return the base index for push constant data. +iree_host_size_t iree_hal_level_zero_push_constant_index( + iree_hal_pipeline_layout_t* base_pipeline_layout); + +// Return the number of constants in the executable layout. +iree_host_size_t iree_hal_level_zero_pipeline_layout_num_constants( + iree_hal_pipeline_layout_t* base_pipeline_layout); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LEVEL_ZERO_PIPELINE_LAYOUT_H_ diff --git a/experimental/level_zero/registration/CMakeLists.txt b/experimental/level_zero/registration/CMakeLists.txt new file mode 100644 index 0000000000000..1336f95640a10 --- /dev/null +++ b/experimental/level_zero/registration/CMakeLists.txt @@ -0,0 +1,21 @@ +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_cc_library( + NAME + registration + HDRS + "driver_module.h" + SRCS + "driver_module.c" + DEPS + iree::base + iree::experimental::level_zero + iree::hal + DEFINES + "IREE_HAVE_HAL_EXPERIMENTAL_LEVEL_ZERO_DRIVER_MODULE=1" + PUBLIC +) diff --git a/experimental/level_zero/registration/driver_module.c b/experimental/level_zero/registration/driver_module.c new file mode 100644 index 0000000000000..79bfb52b0da05 --- /dev/null +++ b/experimental/level_zero/registration/driver_module.c @@ -0,0 +1,55 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/level_zero/registration/driver_module.h" + +#include +#include + +#include "experimental/level_zero/api.h" +#include "iree/base/api.h" + +static iree_status_t iree_hal_level_zero_driver_factory_enumerate( + void *self, iree_host_size_t *out_driver_info_count, + const iree_hal_driver_info_t **out_driver_infos) { + // NOTE: we could query supported LEVEL_ZERO versions or featuresets here. 
+ static const iree_hal_driver_info_t driver_infos[1] = {{ + .driver_name = iree_string_view_literal("level_zero"), + .full_name = iree_string_view_literal("LEVEL_ZERO (dynamic)"), + }}; + *out_driver_info_count = IREE_ARRAYSIZE(driver_infos); + *out_driver_infos = driver_infos; + return iree_ok_status(); +} + +static iree_status_t iree_hal_level_zero_driver_factory_try_create( + void *self, iree_string_view_t driver_name, iree_allocator_t host_allocator, + iree_hal_driver_t **out_driver) { + IREE_ASSERT_ARGUMENT(out_driver); + *out_driver = NULL; + if (!iree_string_view_equal(driver_name, IREE_SV("level_zero"))) { + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "no driver '%.*s' is provided by this factory", + (int)driver_name.size, driver_name.data); + } + IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_level_zero_driver_options_t driver_options; + iree_hal_level_zero_driver_options_initialize(&driver_options); + iree_status_t status = iree_hal_level_zero_driver_create( + driver_name, &driver_options, host_allocator, out_driver); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t iree_hal_level_zero_driver_module_register( + iree_hal_driver_registry_t *registry) { + static const iree_hal_driver_factory_t factory = { + .self = NULL, + .enumerate = iree_hal_level_zero_driver_factory_enumerate, + .try_create = iree_hal_level_zero_driver_factory_try_create, + }; + return iree_hal_driver_registry_register_factory(registry, &factory); +} diff --git a/experimental/level_zero/registration/driver_module.h b/experimental/level_zero/registration/driver_module.h new file mode 100644 index 0000000000000..b7dfc0026e668 --- /dev/null +++ b/experimental/level_zero/registration/driver_module.h @@ -0,0 +1,24 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_REGISTRATION_DRIVER_MODULE_H_ +#define IREE_HAL_LEVEL_ZERO_REGISTRATION_DRIVER_MODULE_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +IREE_API_EXPORT iree_status_t iree_hal_level_zero_driver_module_register( + iree_hal_driver_registry_t *registry); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LEVEL_ZERO_REGISTRATION_DRIVER_MODULE_H_ diff --git a/experimental/level_zero/status_util.c b/experimental/level_zero/status_util.c new file mode 100644 index 0000000000000..2d7cbddd6e25a --- /dev/null +++ b/experimental/level_zero/status_util.c @@ -0,0 +1,242 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/level_zero/status_util.h" + +#include + +#include "experimental/level_zero/dynamic_symbols.h" + +iree_status_t iree_hal_level_zero_result_to_status( + iree_hal_level_zero_dynamic_symbols_t *syms, ze_result_t result, + const char *file, uint32_t line) { + switch (result) { + case ZE_RESULT_SUCCESS: + // Command successfully completed. + return iree_ok_status(); + case ZE_RESULT_NOT_READY: + // A fence or query has not yet completed. + return iree_ok_status(); + case ZE_RESULT_WARNING_DROPPED_DATA: + // [Tools] data may have been dropped + return iree_ok_status(); + case ZE_RESULT_ERROR_DEVICE_LOST: + // Lost device, may be due to device hung, reset, removed, or driver + // update occurred. 
+ return iree_make_status_with_location(file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_DEVICE_LOST"); + case ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY: + // [Core] insufficient host memory to satisfy call + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY"); + case ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY: + // [Core] insufficient device memory to satisfy call + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY"); + case ZE_RESULT_ERROR_MODULE_BUILD_FAILURE: + // [Core] error occurred when building module, see build log for details + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_MODULE_BUILD_FAILURE"); + case ZE_RESULT_ERROR_MODULE_LINK_FAILURE: + // [Core] error occurred when linking modules, see build log for details + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_MODULE_LINK_FAILURE"); + case ZE_RESULT_ERROR_DEVICE_REQUIRES_RESET: + // [Core] device requires a reset + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_DEVICE_REQUIRES_RESET"); + case ZE_RESULT_ERROR_DEVICE_IN_LOW_POWER_STATE: + // [Core] device currently in low power state + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_DEVICE_IN_LOW_POWER_STATE"); + case ZE_RESULT_EXP_ERROR_DEVICE_IS_NOT_VERTEX: + // [Core, Experimental] device is not represented by a fabric vertex + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_EXP_ERROR_DEVICE_IS_NOT_VERTEX"); + case ZE_RESULT_EXP_ERROR_VERTEX_IS_NOT_DEVICE: + // [Core, Experimental] fabric vertex does not represent a device + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_EXP_ERROR_VERTEX_IS_NOT_DEVICE"); + case ZE_RESULT_EXP_ERROR_REMOTE_DEVICE: + // [Core,
Experimental] fabric vertex represents a remote device or + // subdevice + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_EXP_ERROR_REMOTE_DEVICE"); + case ZE_RESULT_ERROR_INSUFFICIENT_PERMISSIONS: + // [Sysman] access denied due to permission level + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INSUFFICIENT_PERMISSIONS"); + case ZE_RESULT_ERROR_NOT_AVAILABLE: + // [Sysman] resource already in use and simultaneous access not allowed or + // resource was removed + return iree_make_status_with_location(file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_NOT_AVAILABLE"); + case ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE: + // [Tools] external required dependency is unavailable or missing + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE"); + case ZE_RESULT_ERROR_UNINITIALIZED: + // [Validation] driver is not initialized + return iree_make_status_with_location(file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_UNINITIALIZED"); + case ZE_RESULT_ERROR_UNSUPPORTED_VERSION: + // [Validation] generic error code for unsupported versions + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_UNSUPPORTED_VERSION"); + case ZE_RESULT_ERROR_UNSUPPORTED_FEATURE: + // [Validation] generic error code for unsupported features + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_UNSUPPORTED_FEATURE"); + case ZE_RESULT_ERROR_INVALID_ARGUMENT: + // [Validation] generic error code for invalid arguments + return iree_make_status_with_location(file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_ARGUMENT"); + case ZE_RESULT_ERROR_INVALID_NULL_HANDLE: + // [Validation] handle argument is not valid + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_NULL_HANDLE"); + case
ZE_RESULT_ERROR_HANDLE_OBJECT_IN_USE: + // [Validation] object pointed to by handle still in-use by device + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_HANDLE_OBJECT_IN_USE"); + case ZE_RESULT_ERROR_INVALID_NULL_POINTER: + // [Validation] pointer argument may not be NULL + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_NULL_POINTER"); + case ZE_RESULT_ERROR_INVALID_SIZE: + // [Validation] size argument is invalid (e.g., must not be zero) + return iree_make_status_with_location(file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_SIZE"); + case ZE_RESULT_ERROR_UNSUPPORTED_SIZE: + // [Validation] size argument is not supported by the device (e.g., too + // large) + return iree_make_status_with_location(file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_UNSUPPORTED_SIZE"); + case ZE_RESULT_ERROR_UNSUPPORTED_ALIGNMENT: + // [Validation] alignment argument is not supported by the device (e.g., + // too small) + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_UNSUPPORTED_ALIGNMENT"); + case ZE_RESULT_ERROR_INVALID_SYNCHRONIZATION_OBJECT: + // [Validation] synchronization object in invalid state + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_SYNCHRONIZATION_OBJECT"); + case ZE_RESULT_ERROR_INVALID_ENUMERATION: + // [Validation] enumerator argument is not valid + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_ENUMERATION"); + case ZE_RESULT_ERROR_UNSUPPORTED_ENUMERATION: + // [Validation] enumerator argument is not supported + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_UNSUPPORTED_ENUMERATION"); + case ZE_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT: + // [Validation] image format is not supported by the device + return iree_make_status_with_location( + file, 
line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT"); + case ZE_RESULT_ERROR_INVALID_NATIVE_BINARY: + // [Validation] native binary is not supported by the device + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_NATIVE_BINARY"); + case ZE_RESULT_ERROR_INVALID_GLOBAL_NAME: + // [Validation] global variable is not found in the module + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_GLOBAL_NAME"); + case ZE_RESULT_ERROR_INVALID_KERNEL_NAME: + // [Validation] kernel name is not found in the module + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_KERNEL_NAME"); + case ZE_RESULT_ERROR_INVALID_FUNCTION_NAME: + // [Validation] function name is not found in the module + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_FUNCTION_NAME"); + case ZE_RESULT_ERROR_INVALID_GROUP_SIZE_DIMENSION: + // [Validation] group size dimension is not valid for the kernel or device + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_GROUP_SIZE_DIMENSION"); + case ZE_RESULT_ERROR_INVALID_GLOBAL_WIDTH_DIMENSION: + // [Validation] global width dimension is not valid for the kernel or + // device + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_GLOBAL_WIDTH_DIMENSION"); + case ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX: + // [Validation] kernel argument index is not valid for kernel + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX"); + case ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE: + // [Validation] kernel argument size does not match kernel + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE"); + case 
ZE_RESULT_ERROR_INVALID_KERNEL_ATTRIBUTE_VALUE: + // [Validation] value of kernel attribute is not valid for the kernel or + // device + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_KERNEL_ATTRIBUTE_VALUE"); + case ZE_RESULT_ERROR_INVALID_MODULE_UNLINKED: + // [Validation] module with imports needs to be linked before kernels can + // be created from it + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_MODULE_UNLINKED"); + case ZE_RESULT_ERROR_INVALID_COMMAND_LIST_TYPE: + // [Validation] command list type does not match command queue type + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_INVALID_COMMAND_LIST_TYPE"); + case ZE_RESULT_ERROR_OVERLAPPING_REGIONS: + // [Validation] copy operations do not support overlapping regions of + // memory + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_OVERLAPPING_REGIONS"); + case ZE_RESULT_WARNING_ACTION_REQUIRED: + // [Sysman] an action is required to complete the desired operation + return iree_make_status_with_location( + file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_WARNING_ACTION_REQUIRED"); + case ZE_RESULT_ERROR_UNKNOWN: + // [Core] unknown or internal error + return iree_make_status_with_location(file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_ERROR_UNKNOWN"); + case ZE_RESULT_FORCE_UINT32: + // [Core] unknown or internal error + return iree_make_status_with_location(file, line, IREE_STATUS_INTERNAL, + "ZE_RESULT_FORCE_UINT32"); + default: + // [Core] unknown or internal error + return iree_make_status_with_location(file, line, IREE_STATUS_INTERNAL, + "Unknown error found"); + } +} diff --git a/experimental/level_zero/status_util.h b/experimental/level_zero/status_util.h new file mode 100644 index 0000000000000..b57ec94230f80 --- /dev/null +++ b/experimental/level_zero/status_util.h @@ -0,0 +1,63 @@ +// Copyright 2021 
The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_LEVEL_ZERO_STATUS_UTIL_H_ +#define IREE_HAL_LEVEL_ZERO_STATUS_UTIL_H_ + +#include + +#include "experimental/level_zero/dynamic_symbols.h" +#include "iree/base/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Converts a ze_result_t to an iree_status_t. +// +// Usage: +// iree_status_t status = LEVEL_ZERO_RESULT_TO_STATUS(syms, zeDoThing(...)); +#define LEVEL_ZERO_RESULT_TO_STATUS(syms, expr, ...) \ + iree_hal_level_zero_result_to_status((syms), ((syms)->expr), __FILE__, \ + __LINE__) + +// IREE_RETURN_IF_ERROR but implicitly converts the ze_result_t return value to +// a Status. +// +// Usage: +// LEVEL_ZERO_RETURN_IF_ERROR(syms, zeDoThing(...), "message"); +#define LEVEL_ZERO_RETURN_IF_ERROR(syms, expr, ...) \ + IREE_RETURN_IF_ERROR(iree_hal_level_zero_result_to_status( \ + (syms), ((syms)->expr), __FILE__, __LINE__), \ + __VA_ARGS__) + +// IREE_IGNORE_ERROR but implicitly converts the ze_result_t return value to a +// Status. +// +// Usage: +// LEVEL_ZERO_IGNORE_ERROR(syms, zeDoThing(...)); +#define LEVEL_ZERO_IGNORE_ERROR(syms, expr) \ + IREE_IGNORE_ERROR(iree_hal_level_zero_result_to_status( \ + (syms), ((syms)->expr), __FILE__, __LINE__)) + +// Converts a ze_result_t to a Status object. +iree_status_t iree_hal_level_zero_result_to_status( + iree_hal_level_zero_dynamic_symbols_t* syms, ze_result_t result, + const char* file, uint32_t line); + +#define IREE_LEVEL_ZERO_TRY(...) 
\ + do { \ + status = __VA_ARGS__; \ + if (!iree_status_is_ok(status)) { \ + goto cleanup; \ + } \ + } while (0) + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LEVEL_ZERO_STATUS_UTIL_H_ diff --git a/experimental/level_zero/test/CMakeLists.txt b/experimental/level_zero/test/CMakeLists.txt new file mode 100644 index 0000000000000..8b2921a36f407 --- /dev/null +++ b/experimental/level_zero/test/CMakeLists.txt @@ -0,0 +1,19 @@ +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +iree_cc_test( + NAME + level_zero_test + SRCS + level_zero_test.cc + DEPS + iree::runtime + iree::testing::gtest + iree::testing::gtest_main + PUBLIC +) diff --git a/experimental/level_zero/test/level_zero_test.cc b/experimental/level_zero/test/level_zero_test.cc new file mode 100644 index 0000000000000..953a1cbf873ac --- /dev/null +++ b/experimental/level_zero/test/level_zero_test.cc @@ -0,0 +1,86 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include +#include +#include + +template +std::unique_ptr make_unique(T* p, Deleter d) { + return std::unique_ptr(p, d); +} + +TEST(LevelZeroTest, UUID) { + iree_runtime_instance_options_t instance_options; + iree_runtime_instance_options_initialize(&instance_options); + iree_runtime_instance_options_use_all_available_drivers(&instance_options); + + static const char* driver_name = "level_zero"; + + iree_allocator_t host_allocator = iree_allocator_system(); + + // Make instance. 
+ iree_runtime_instance_t* instance = nullptr; + ASSERT_EQ(iree_runtime_instance_create(&instance_options, + iree_allocator_system(), &instance), + iree_ok_status()); + auto instance_deleter = make_unique( + instance, + [](iree_runtime_instance_t* p) { iree_runtime_instance_release(p); }); + + // Make Level Zero driver. + iree_hal_driver_registry_t* driver_registry = + iree_runtime_instance_driver_registry(instance); + iree_hal_driver_t* driver = nullptr; + ASSERT_EQ(iree_hal_driver_registry_try_create( + driver_registry, iree_make_cstring_view(driver_name), + host_allocator, &driver), + iree_ok_status()); + auto driver_deleter = make_unique( + driver, [](iree_hal_driver_t* p) { iree_hal_driver_release(p); }); + + // Get list of available devices. + iree_hal_device_info_t* device_infos = nullptr; + iree_host_size_t device_infos_count = 0; + ASSERT_EQ(iree_hal_driver_query_available_devices( + driver, host_allocator, &device_infos_count, &device_infos), + iree_ok_status()); + auto device_infos_deleter = make_unique( + device_infos, [host_allocator](iree_hal_device_info_t* p) { + iree_allocator_free(host_allocator, p); + }); + ASSERT_GT(device_infos_count, 0); + + // Create a valid device from URI. + std::stringstream device_uri; + device_uri << driver_name << "://"; + device_uri << std::string(device_infos[0].path.data, + device_infos[0].path.size); + std::string device_uri_str = device_uri.str(); + iree_hal_device_t* device = nullptr; + ASSERT_EQ(iree_hal_driver_create_device_by_uri( + driver, iree_make_cstring_view(device_uri_str.c_str()), + host_allocator, &device), + iree_ok_status()); + auto device_deleter = make_unique( + device, [](iree_hal_device_t* p) { iree_hal_device_release(p); }); + + // Try create an invalid device from URI. 
+ std::stringstream invalid_device_uri; + invalid_device_uri << driver_name + << "://4e5a272e-66a7-11ed-9342-4f1f581f812c"; + std::string invalid_device_uri_str = invalid_device_uri.str(); + iree_hal_device_t* invalid_device = nullptr; + ASSERT_NE(iree_hal_driver_create_device_by_uri( + driver, iree_make_cstring_view(invalid_device_uri_str.c_str()), + host_allocator, &invalid_device), + iree_ok_status()); + auto invalid_device_deleter = make_unique( + invalid_device, [](iree_hal_device_t* p) { iree_hal_device_release(p); }); +} diff --git a/runtime/src/iree/schemas/CMakeLists.txt b/runtime/src/iree/schemas/CMakeLists.txt index ef93a2bb10e5a..43b4cf51290ca 100644 --- a/runtime/src/iree/schemas/CMakeLists.txt +++ b/runtime/src/iree/schemas/CMakeLists.txt @@ -99,4 +99,16 @@ iree_cc_library( PUBLIC ) +flatbuffer_c_library( + NAME + level_zero_executable_def_c_fbs + SRCS + "level_zero_executable_def.fbs" + FLATCC_ARGS + "--reader" + "--builder" + "--verifier" + "--json" + PUBLIC +) ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/runtime/src/iree/schemas/level_zero_executable_def.fbs b/runtime/src/iree/schemas/level_zero_executable_def.fbs new file mode 100644 index 0000000000000..c2b5f968f7539 --- /dev/null +++ b/runtime/src/iree/schemas/level_zero_executable_def.fbs @@ -0,0 +1,33 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +namespace iree; + +// 'LEVEL_ZERO Executable'. +file_identifier "ZERO"; +file_extension "zero"; + +// A struct for the kernel block size along each dimensions. +struct LEVEL_ZEROBlockSizeDef { + x:uint32; + y:uint32; + z:uint32; +} + +table LEVEL_ZEROExecutableDef { + // A map of entry point ordinals to string names as used in the shader + // library. + entry_points:[string]; + + // Block sizes for each entry point. 
+ // + block_sizes:[LEVEL_ZEROBlockSizeDef]; + + // SPIR-V binary of the module (as 32-bit words). + level_zero_image:[uint32]; +} + +root_type LEVEL_ZEROExecutableDef; diff --git a/tests/e2e/stablehlo_ops/CMakeLists.txt b/tests/e2e/stablehlo_ops/CMakeLists.txt index c13f62e417a07..304ea56d3a003 100644 --- a/tests/e2e/stablehlo_ops/CMakeLists.txt +++ b/tests/e2e/stablehlo_ops/CMakeLists.txt @@ -221,6 +221,75 @@ iree_check_single_backend_test_suite( "--iree-input-type=stablehlo" ) +iree_check_single_backend_test_suite( + NAME + check_vulkan-spirv_opencl + SRCS + "abs.mlir" + "add.mlir" + "batch_norm_inference.mlir" + "bitcast_convert.mlir" + "broadcast.mlir" + "broadcast_add.mlir" + "broadcast_in_dim.mlir" + "clamp.mlir" + "compare.mlir" + "complex.mlir" + "concatenate.mlir" + "constant.mlir" + "convert.mlir" + "convolution.mlir" + "cosine.mlir" + "divide.mlir" + "dot.mlir" + "dot_bf16.mlir" + "dot_general.mlir" + "dynamic_slice.mlir" + "dynamic_update_slice.mlir" + "exponential.mlir" + "exponential_minus_one.mlir" + "finite.mlir" + "floor.mlir" + "gather.mlir" + "iota.mlir" + "log.mlir" + "log_plus_one.mlir" + "maximum.mlir" + "minimum.mlir" + "multiply.mlir" + "negate.mlir" + "pad.mlir" + "philox.mlir" + "pow.mlir" + "reduce.mlir" + "reduce_window.mlir" + "remainder.mlir" + "reshape.mlir" + "rng_normal.mlir" + "rng_uniform.mlir" + "round.mlir" + "rsqrt.mlir" + "scatter.mlir" + "scatter_dynamic.mlir" + "select.mlir" + "sine.mlir" + "slice.mlir" + "sort.mlir" + "sqrt.mlir" + "subtract.mlir" + "tanh.mlir" + "three_fry.mlir" + "torch_index_select.mlir" + "transpose.mlir" + "while.mlir" + TARGET_BACKEND + "opencl-spirv" + DRIVER + "level_zero" + COMPILER_FLAGS + "--iree-input-type=stablehlo" +) + iree_check_single_backend_test_suite( NAME check_llvm-cpu-host_local-task diff --git a/tests/e2e/tosa_ops/CMakeLists.txt b/tests/e2e/tosa_ops/CMakeLists.txt index a4a0d7ea99ab4..caf115532b1cc 100644 --- a/tests/e2e/tosa_ops/CMakeLists.txt +++ b/tests/e2e/tosa_ops/CMakeLists.txt @@ -221,6 +221,56
@@ iree_check_single_backend_test_suite( "--iree-input-type=tosa" ) +iree_check_single_backend_test_suite( + NAME + check_level_zero-spirv_opencl + SRCS + "abs.mlir" + "add.mlir" + "arithmetic_right_shift.mlir" + "bitwise_and.mlir" + "bitwise_or.mlir" + "bitwise_xor.mlir" + "ceil.mlir" + "clamp.mlir" + "const.mlir" + "equal.mlir" + "exp.mlir" + "floor.mlir" + "fully_connected.mlir" + "gather.mlir" + "greater.mlir" + "greater_equal.mlir" + "if.mlir" + "log.mlir" + "logical_left_shift.mlir" + "logical_right_shift.mlir" + "matmul.mlir" + "max_pool.mlir" + "maximum.mlir" + "minimum.mlir" + "mul.mlir" + "negate.mlir" + "pad.mlir" + "reciprocal.mlir" + "reduce.mlir" + "reshape.mlir" + "rsqrt.mlir" + "select.mlir" + "sigmoid.mlir" + "sub.mlir" + "table.mlir" + "tanh.mlir" + "transpose.mlir" + "while.mlir" + TARGET_BACKEND + "opencl-spirv" + DRIVER + "level_zero" + COMPILER_FLAGS + "--iree-input-type=tosa" +) + ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### iree_check_single_backend_test_suite( diff --git a/third_party/level-zero b/third_party/level-zero new file mode 160000 index 0000000000000..474188ae004a5 --- /dev/null +++ b/third_party/level-zero @@ -0,0 +1 @@ +Subproject commit 474188ae004a5c76953a829477997bc341e70d48 From 8ca6db66efc14c1c74ea556b5718c5ebba1d7439 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 10 Aug 2023 11:44:55 -0700 Subject: [PATCH 21/44] [Distributed] Rudimentary distributed Python API (#64) * Add rudimentary non-production distributed Python API * Distributed execution validation Add functionality that validates distributed StableHLO is producing the same results as non-distributed. 
* Add execution time measurement * Distributed Python API: add call_count to run_ranks * Add setup script for distributed Python API * Add JAX to install setup --------- Co-authored-by: Boian Petkantchin --- runtime/bindings/python/CMakeLists.txt | 5 + .../bindings/python/iree/runtime/__init__.py | 1 + .../iree/runtime/distributed/__init__.py | 9 + .../iree/runtime/distributed/distributed.py | 86 +++++++ .../iree/runtime/distributed/run_rank.py | 132 +++++++++++ .../python/iree/runtime/distributed/setup.sh | 15 ++ .../distributed/sharding_pass_validation.py | 210 ++++++++++++++++++ .../python/iree/runtime/distributed/utils.py | 26 +++ 8 files changed, 484 insertions(+) create mode 100644 runtime/bindings/python/iree/runtime/distributed/__init__.py create mode 100644 runtime/bindings/python/iree/runtime/distributed/distributed.py create mode 100644 runtime/bindings/python/iree/runtime/distributed/run_rank.py create mode 100644 runtime/bindings/python/iree/runtime/distributed/setup.sh create mode 100644 runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py create mode 100644 runtime/bindings/python/iree/runtime/distributed/utils.py diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt index aba101dc52fad..5d6c29d822250 100644 --- a/runtime/bindings/python/CMakeLists.txt +++ b/runtime/bindings/python/CMakeLists.txt @@ -130,6 +130,11 @@ iree_py_library( "iree/_runtime/scripts/iree_run_trace/__main__.py" "iree/_runtime/scripts/iree_run_module/__main__.py" "iree/_runtime/scripts/iree_tracy_capture/__main__.py" + "iree/runtime/distributed/__init__.py" + "iree/runtime/distributed/distributed.py" + "iree/runtime/distributed/run_rank.py" + "iree/runtime/distributed/sharding_pass_validation.py" + "iree/runtime/distributed/utils.py" PYEXT_DEPS iree_runtime_bindings_python_PyExtRt ) diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py index 
140de1010aa5a..b398cdda65aeb 100644 --- a/runtime/bindings/python/iree/runtime/__init__.py +++ b/runtime/bindings/python/iree/runtime/__init__.py @@ -57,4 +57,5 @@ from .function import * from .tracing import * +from . import distributed from . import flags diff --git a/runtime/bindings/python/iree/runtime/distributed/__init__.py b/runtime/bindings/python/iree/runtime/distributed/__init__.py new file mode 100644 index 0000000000000..86ee5db110ccb --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .distributed import prepare_shards_io_files, run_ranks + +__all__ = ["prepare_shards_io_files", "run_ranks"] diff --git a/runtime/bindings/python/iree/runtime/distributed/distributed.py b/runtime/bindings/python/iree/runtime/distributed/distributed.py new file mode 100644 index 0000000000000..258e517b2cf23 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/distributed.py @@ -0,0 +1,86 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import iree.compiler +import sys +import iree.runtime +from iree.runtime.array_interop import DeviceArray +import os +from numpy.typing import ArrayLike +from typing import List, Tuple +import tempfile +import subprocess +from . 
import utils + + +def prepare_shards_io_files( + inputs: List[List[ArrayLike]], out_dir: str +) -> Tuple[List[str], List[str]]: + input_filepaths = [] + output_filepaths = [] + for i in range(len(inputs)): + input_filepath = os.path.join(out_dir, f"shard_{i}", "input.npy") + input_filepaths.append(input_filepath) + os.makedirs(os.path.dirname(input_filepath)) + utils.write_numpy_arrays_to_file(filepath=input_filepath, arrays=inputs[i]) + output_filepath = os.path.join(out_dir, f"shard_{i}", "output.npy") + output_filepaths.append(output_filepath) + return input_filepaths, output_filepaths + + +def run_ranks( + num_ranks: int, + module_filepath: str, + function: str, + inputs: List[List[ArrayLike]], + driver: str, + call_count: int = 1, + measure_execution_time: bool = False, + warmup: int = 0, +) -> List[List[ArrayLike]]: + """ + Start all ranks with mpirun. + On all ranks run the function |function| from the given module. + Parameters + ---------- + inputs : Function inputs for all ranks. + Axis 0 is ranks. Axis 1 is arguments per rank. + Returns + ------- + The output of the function for all ranks. + Axis 0 is ranks. Axis 1 is arguments per rank. 
+ """ + with tempfile.TemporaryDirectory() as out_dir: + input_filepaths, output_filepaths = prepare_shards_io_files( + inputs=inputs, out_dir=out_dir + ) + hal_driver = iree.runtime.get_driver(driver) + hal_driver.query_available_devices() + subprocess.check_call( + [ + "mpirun", + "--oversubscribe", + "-n", + str(num_ranks), + sys.executable, + os.path.join(os.path.dirname(__file__), "run_rank.py"), + f"--driver={driver}", + f"--module_filepath={module_filepath}", + f"--function={function}", + f"--call_count={call_count}", + ] + + (["--measure_execution_time"] if measure_execution_time else []) + + [ + f"--warmup={warmup}", + "--inputs", + ] + + input_filepaths + + ["--outputs"] + + output_filepaths + ) + return [ + utils.read_numpy_arrays_from_file(out_file) for out_file in output_filepaths + ] diff --git a/runtime/bindings/python/iree/runtime/distributed/run_rank.py b/runtime/bindings/python/iree/runtime/distributed/run_rank.py new file mode 100644 index 0000000000000..7ad00f7256cca --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/run_rank.py @@ -0,0 +1,132 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import iree.compiler +import argparse +import iree.runtime +from iree.runtime.array_interop import DeviceArray +from mpi4py import MPI +import utils +import datetime +import numpy as np + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run 1 shard.") + parser.add_argument("--driver", type=str, default="local-task", help="Device URI.") + parser.add_argument( + "--module_filepath", type=str, required=True, help="Path to IREE module." + ) + parser.add_argument( + "--function", type=str, required=True, help="Name of function to call." 
+ ) + parser.add_argument( + "--call_count", + type=int, + default=1, + help="How many times to call the function during time measurement.", + ) + parser.add_argument( + "--measure_execution_time", + action="store_true", + default=False, + help="Measure execution time in seconds f64 and append to results.", + ) + parser.add_argument( + "--warmup", + type=int, + default=0, + help="How many warmup calls to do before the actual call that generates the result.", + ) + parser.add_argument( + "--inputs", + nargs="+", + type=str, + required=True, + help="Path to IREE module inputs for all ranks in npy format.", + ) + parser.add_argument( + "--outputs", + nargs="+", + type=str, + required=True, + help="Path to IREE module outputs form all ranks in npy format.", + ) + return parser.parse_args() + + +def run_module( + device: iree.runtime.HalDevice, + module_filepath: str, + function: str, + call_count: int, + input_filepath: str, + output_filepath: str, + measure_execution_time: bool, + warmup: int, +): + config = iree.runtime.Config(device=device) + with open(module_filepath, "rb") as f: + vm_flatbuffer = f.read() + vm_module = iree.runtime.VmModule.from_flatbuffer(config.vm_instance, vm_flatbuffer) + bound_module = iree.runtime.load_vm_module(vm_module, config) + input_args = utils.read_numpy_arrays_from_file(input_filepath) + input_args_on_device = [ + iree.runtime.asdevicearray(device, arr) for arr in input_args + ] + for _ in range(warmup): + getattr(bound_module, function)(*input_args_on_device) + if measure_execution_time: + # Sync all ranks + MPI.COMM_WORLD.barrier() + start_time = datetime.datetime.now() + assert call_count > 0 + for _ in range(call_count): + results = getattr(bound_module, function)(*input_args_on_device) + if measure_execution_time: + end_time = datetime.datetime.now() + if isinstance(results, DeviceArray): + results = [results] + if measure_execution_time: + if isinstance(results, tuple): + results = list(results) + results.append( + 
np.array((end_time - start_time).total_seconds() / call_count, dtype=float) + ) + utils.write_numpy_arrays_to_file(filepath=output_filepath, arrays=results) + + +def run_rank( + driver: str, + module_filepath: str, + function: str, + inputs: str, + outputs: str, + call_count: int, + measure_execution_time: bool, + warmup: int, +): + rank = MPI.COMM_WORLD.Get_rank() + hal_driver = iree.runtime.get_driver(driver) + device_infos = hal_driver.query_available_devices() + device = hal_driver.create_device( + device_infos[rank % len(device_infos)]["device_id"] + ) + run_module( + device=device, + module_filepath=module_filepath, + function=function, + call_count=call_count, + input_filepath=inputs[rank], + output_filepath=outputs[rank], + measure_execution_time=measure_execution_time, + warmup=warmup, + ) + + +if __name__ == "__main__": + args = parse_args() + run_rank(**vars(args)) diff --git a/runtime/bindings/python/iree/runtime/distributed/setup.sh b/runtime/bindings/python/iree/runtime/distributed/setup.sh new file mode 100644 index 0000000000000..83dca488caa4f --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/setup.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +set -e + +distribution=$(. /etc/os-release;echo $ID$VERSION_ID | sed -e 's/\.//g') +wget -O /tmp/cuda-keyring_1.0-1_all.deb \ + https://developer.download.nvidia.com/compute/cuda/repos/$distribution/x86_64/cuda-keyring_1.0-1_all.deb +sudo dpkg -i /tmp/cuda-keyring_1.0-1_all.deb +sudo apt update +# For CMake to find CUDA when using LLD. 
+sudo apt -y install lld + +sudo apt -y install libopenmpi-dev +sudo apt -y install libnccl-dev=2.18.1-1+cuda12.1 +pip install mpi4py jax[cpu] diff --git a/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py b/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py new file mode 100644 index 0000000000000..4465204516237 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py @@ -0,0 +1,210 @@ +import iree.compiler +import iree.runtime +import os +from iree.runtime.distributed import run_ranks +import subprocess +from pathlib import Path +from jax._src.lib import xla_client +from jaxlib.xla_client import HloSharding +from typing import List, Tuple, Union +from numpy.typing import ArrayLike +import jax +from jax._src.sharding_impls import GSPMDSharding +import jax._src.interpreters.pxla as pxla +import numpy as np +from datetime import timedelta + +xla_extension = xla_client._xla + + +def compile_mlir(mlir_filepath: str, output_filepath: str, use_cache: bool, **kwargs): + if use_cache and os.path.exists(output_filepath): + return + os.makedirs(os.path.dirname(output_filepath), exist_ok=True) + iree.compiler.compile_file( + input_file=mlir_filepath, output_file=output_filepath, **kwargs + ) + + +def extract_args_sharding( + xla_computation: xla_extension.XlaComputation, +) -> List[HloSharding]: + return [ + HloSharding.from_proto(sharding) + for sharding in xla_computation.get_hlo_module().spmd_parameters_shardings + ] + + +def extract_results_sharding( + xla_computation: xla_extension.XlaComputation, +) -> List[HloSharding]: + sharding = HloSharding.from_proto( + xla_computation.get_hlo_module().spmd_output_sharding + ) + if len(sharding.tuple_elements()): + return sharding.tuple_elements() + else: + return [sharding] + + +def shard_arg(arg: ArrayLike, sharding: HloSharding) -> List[ArrayLike]: + gspmd_sharding = GSPMDSharding(devices=jax.local_devices(), op_sharding=sharding) + 
indices = gspmd_sharding.devices_indices_map(arg.shape).values() + sharded_array = pxla.shard_arg( + arg, devices=jax.local_devices(), arg_indices=indices, sharding=gspmd_sharding + ) + return [shard.data for shard in sharded_array.global_shards] + + +def shard_args( + args: List[ArrayLike], shardings: List[HloSharding] +) -> List[List[ArrayLike]]: + assert len(args) == len(shardings) + return [shard_arg(arg, sharding) for arg, sharding in zip(args, shardings)] + + +def assemble_shards(shards: List[ArrayLike], sharding: HloSharding) -> ArrayLike: + if sharding.is_replicated(): + return shards[0] + else: + raise NotImplementedError() + + +def propagate_shardings_and_spmd_partition( + mlir_filepath: str, + output_filepath: str, + num_devices: int, + use_cache: bool, + allow_spmd_sharding_propagation_to_output: int = 1, +): + res = subprocess.run( + [ + "stablehlo-opt", + ( + "--pass-pipeline=builtin.module(stablehlo-xla-sharding-propagation-and-spmd-partitioner{" + "is_spmd=1 " + f"allow_spmd_sharding_propagation_to_output={allow_spmd_sharding_propagation_to_output} " + "allow_spmd_sharding_propagation_to_parameters=1 " + f"num_partitions={num_devices} " + "num_replicas=1})" + ), + mlir_filepath, + ], + check=True, + stdout=subprocess.PIPE, + ) + Path(output_filepath).parent.mkdir(parents=True, exist_ok=True) + if use_cache and os.path.exists(output_filepath): + return + os.makedirs(os.path.dirname(output_filepath), exist_ok=True) + with open(output_filepath, "wb") as f: + f.write(res.stdout) + + +def swap_shard_axis(arrays: List[ArrayLike]) -> List[List[ArrayLike]]: + """Swap axis 0 with 1.""" + if len(arrays) == 0: + return [] + expected_shards = len(arrays[0]) + res = [[] for _ in range(expected_shards)] + for arr in arrays: + assert len(arr) == expected_shards + for shard in range(expected_shards): + res[shard].append(arr[shard]) + return res + + +def execute_distributed( + num_ranks: int, + mlir_filepath: str, + iree_module_filepath: str, + function: str, + 
inputs: List[ArrayLike], + driver: str, + measure_execution_time: bool = False, +) -> Union[List[ArrayLike], Tuple[List[ArrayLike], timedelta]]: + with open(mlir_filepath, "r") as f: + mlir_str = f.read() + xla_computation = xla_extension.mlir.mlir_module_to_xla_computation( + mlir_module=mlir_str, use_tuple_args=False, return_tuple=False + ) + args_sharding = extract_args_sharding(xla_computation) + results_sharding = extract_results_sharding(xla_computation) + sharded_args = shard_args(args=inputs, shardings=args_sharding) + sharded_args = swap_shard_axis(sharded_args) + sharded_results = run_ranks( + num_ranks=num_ranks, + module_filepath=iree_module_filepath, + function=function, + inputs=sharded_args, + driver=driver, + ) + sharded_results = swap_shard_axis(sharded_results) + if measure_execution_time: + sharded_results, execution_times = sharded_results + res = [ + assemble_shards(shards=result_shards, sharding=sharding) + for result_shards, sharding in zip(sharded_results, results_sharding) + ] + if measure_execution_time: + res = res, timedelta(seconds=np.max(execution_times)) + return res + + +def validate_sharding_passes( + mlir_filepath: str, + mlir_with_sharding_annotations_filepath: str, + inputs: List[ArrayLike], + function: str, + num_devices: int, + use_cache: bool, + driver: str, + target_backend: str, + output_prefix_path: str, + allow_spmd_sharding_propagation_to_output: int = 1, +): + # Single instance. + iree_module_filepath = ( + f"{output_prefix_path}{os.path.basename(mlir_filepath)}.{driver}.vmfb" + ) + compile_mlir( + mlir_filepath=mlir_filepath, + output_filepath=iree_module_filepath, + use_cache=use_cache, + target_backends=[target_backend], + ) + iree_module = iree.runtime.load_vm_flatbuffer_file( + path=iree_module_filepath, driver=driver + ) + results = iree_module[function](*inputs) + if isinstance(results, iree.runtime.DeviceArray): + results = [results] + + # Distributed. 
+ spmd_mlir_filepath = f"{output_prefix_path}{os.path.basename(mlir_with_sharding_annotations_filepath)}.spmd.mlir" + propagate_shardings_and_spmd_partition( + mlir_filepath=mlir_with_sharding_annotations_filepath, + output_filepath=spmd_mlir_filepath, + num_devices=num_devices, + use_cache=use_cache, + allow_spmd_sharding_propagation_to_output=allow_spmd_sharding_propagation_to_output, + ) + spmd_iree_module_filepath = f"{output_prefix_path}{os.path.basename(spmd_mlir_filepath)}.{target_backend}.vmfb" + compile_mlir( + mlir_filepath=spmd_mlir_filepath, + output_filepath=spmd_iree_module_filepath, + use_cache=use_cache, + target_backends=[target_backend], + ) + spmd_results = execute_distributed( + num_ranks=num_devices, + mlir_filepath=spmd_mlir_filepath, + iree_module_filepath=spmd_iree_module_filepath, + function=function, + inputs=inputs, + driver=driver, + ) + + assert len(results) == len(spmd_results) + for result, spmd_result in zip(results, spmd_results): + np.testing.assert_allclose(result, spmd_result, atol=1e-7) diff --git a/runtime/bindings/python/iree/runtime/distributed/utils.py b/runtime/bindings/python/iree/runtime/distributed/utils.py new file mode 100644 index 0000000000000..3581baf354f86 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/utils.py @@ -0,0 +1,26 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from numpy.typing import ArrayLike +from typing import List +import numpy as np + + +def read_numpy_arrays_from_file(filepath: str) -> List[ArrayLike]: + res = [] + with open(filepath, "rb") as f: + while True: + try: + res.append(np.load(f)) + except EOFError: + break + return res + + +def write_numpy_arrays_to_file(filepath: str, arrays: List[ArrayLike]): + with open(filepath, "wb") as f: + for arr in arrays: + np.save(f, np.asarray(arr), allow_pickle=False) From 75745d39e348ecacde1dc2bbae6509e051779463 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 10 Aug 2023 13:13:19 -0700 Subject: [PATCH 22/44] [Distributed] Add example to run a simple model across 2 GPUs --- samples/distributed/example.py | 55 ++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 samples/distributed/example.py diff --git a/samples/distributed/example.py b/samples/distributed/example.py new file mode 100644 index 0000000000000..ff0989403df91 --- /dev/null +++ b/samples/distributed/example.py @@ -0,0 +1,55 @@ +from iree.runtime.distributed import run_ranks +import iree.compiler +import tempfile +import numpy as np +import os + +""" +Example of distributed execution across 2 devices of a small model +with just an all-reduce operation. +all_reduce([1, 2, 3, 4], [5, 6, 7, 8]) -> [6, 8, 10, 12]. 
+ +Dependencies at: +runtime/bindings/python/iree/runtime/distributed/setup.sh +""" +mlir = """ + func.func @all_reduce_sum(%input : tensor<4xf32>) -> tensor<4xf32> { + %out = "stablehlo.all_reduce"(%input) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %sum = stablehlo.add %arg0, %arg1 : tensor + stablehlo.return %sum : tensor + }) {channel_handle = #stablehlo.channel_handle, + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + use_global_device_ids} : (tensor<4xf32>) -> tensor<4xf32> + return %out : tensor<4xf32> + } +""" + +inputs = [ + [np.array([1, 2, 3, 4], dtype=np.float32)], + [np.array([5, 6, 7, 8], dtype=np.float32)], +] + +for rank in range(len(inputs)): + print(f"Rank {rank} argument = {inputs[rank]}") + +with tempfile.TemporaryDirectory() as tmp_dir: + module_filepath = os.path.join(tmp_dir, "module.vmfb") + iree.compiler.tools.compile_str( + input_str=mlir, + output_file=module_filepath, + target_backends=["cuda"], + input_type="stablehlo", + ) + + num_ranks = len(inputs) + # Ranks on the 0th axis. + outputs = run_ranks( + num_ranks=num_ranks, + function="all_reduce_sum", + driver="cuda", + module_filepath=module_filepath, + inputs=inputs, + ) + for rank in range(num_ranks): + print(f"Rank {rank} result = {outputs[rank]}") From 5ee1bd89861d25e11a6cfdcb2b034142e04584ce Mon Sep 17 00:00:00 2001 From: Anush Elangovan Date: Wed, 14 Sep 2022 05:31:02 -0700 Subject: [PATCH 23/44] [CI] Switch to GHA linux runners, remove TF builds, move macOS to self-hosted, clean macos bindist Drop instrumented builds and Python < 3.11 Add Upstream sync CI This fixes the problem of potentially dropping commits that have been submitted while an automatic rebase with upstream IREE is going on. [CI] Fix macos clean up logic Fixes the macos builder. 
--- .github/workflows/build_package.yml | 26 +++---- .github/workflows/sync.yml | 69 +++++++++++++++++++ .../python_deploy/build_linux_packages.sh | 2 +- build_tools/scripts/get_latest_green.sh | 11 --- 4 files changed, 80 insertions(+), 28 deletions(-) create mode 100644 .github/workflows/sync.yml diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index ca17fc45f97b0..5aaff26541831 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -39,35 +39,25 @@ jobs: matrix: include: # Ubuntu packages. - - runs-on: [managed-releaser, os-family=Linux, runner-group=releaser] + - runs-on: ubuntu-latest build-family: linux build-package: main-dist-linux experimental: false - - runs-on: [managed-releaser, os-family=Linux, runner-group=releaser] + - runs-on: ubuntu-latest build-family: linux build-package: py-compiler-pkg experimental: false - - runs-on: [managed-releaser, os-family=Linux, runner-group=releaser] + - runs-on: ubuntu-latest build-family: linux build-package: py-runtime-pkg experimental: false - - runs-on: [managed-releaser, os-family=Linux, runner-group=releaser] - build-family: linux - build-package: py-tf-compiler-tools-pkg - experimental: false - # MacOS packages. - - runs-on: - - ${{ github.repository == 'openxla/iree' && 'self-hosted' || 'macos-11' }} - - os-family=macOS - - runner-group=postsubmit + # Macos packages. + - runs-on: MacStudio build-family: macos build-package: py-compiler-pkg experimental: true - - runs-on: - - ${{ github.repository == 'openxla/iree' && 'self-hosted' || 'macos-11' }} - - os-family=macOS - - runner-group=postsubmit + - runs-on: MacStudio build-family: macos build-package: py-runtime-pkg experimental: true @@ -92,6 +82,10 @@ jobs: path: "c" # Windows can hit path length limits, so use a short path. 
submodules: true ref: ${{ github.event.inputs.commit }} + - uses: actions/setup-python@v4 + if: "matrix.build-family != 'macos'" + with: + python-version: '3.11' ########################################################################## # OS specific setup diff --git a/.github/workflows/sync.yml b/.github/workflows/sync.yml new file mode 100644 index 0000000000000..2fb41e7b7c685 --- /dev/null +++ b/.github/workflows/sync.yml @@ -0,0 +1,69 @@ +name: 'Sync Upstream' + +on: + workflow_dispatch: + schedule: + - cron: '0 * * * *' + +jobs: + sync_upstream: + name: 'Sync Upstream' + runs-on: ubuntu-latest + steps: + - name: Checking out repository + uses: actions/checkout@v3 + with: + token: ${{ secrets.CI_WRITE_TOKEN }} + repository: nod-ai/shark-runtime + ref: main + fetch-depth: 0 + + - name: Setup git + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "SHARK bot" + + - name: Update main upstream + run: | + set -ex + git remote add upstream https://github.com/iree-org/iree + git pull --ff-only upstream main + + - name: Pushing changes + uses: ad-m/github-push-action@master + with: + github_token: ${{ secrets.CI_WRITE_TOKEN }} + branch: main + repository: nod-ai/shark-runtime + + rebase_shark: + name: 'Rebase SHARK' + runs-on: ubuntu-latest + steps: + - name: Checking out repository + uses: actions/checkout@v3 + with: + token: ${{ secrets.CI_WRITE_TOKEN }} + repository: nod-ai/shark-runtime + ref: shark + fetch-depth: 0 + + - name: Setup git + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "SHARK bot" + + - name: Update shark upstream + run: | + set -ex + git remote add upstream https://github.com/iree-org/iree + git fetch upstream + git rebase upstream/main + + - name: Pushing changes + uses: ad-m/github-push-action@master + with: + github_token: ${{ secrets.CI_WRITE_TOKEN }} + branch: shark + repository: nod-ai/shark-runtime + 
force_with_lease: true diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 9116db182c2dd..3363d35b421cb 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -65,7 +65,7 @@ this_dir="$(cd $(dirname $0) && pwd)" script_name="$(basename $0)" repo_root=$(cd "${this_dir}" && find_git_dir_parent) manylinux_docker_image="${manylinux_docker_image:-$(uname -m | awk '{print ($1 == "aarch64") ? "quay.io/pypa/manylinux_2_28_aarch64" : "ghcr.io/nod-ai/manylinux_x86_64:main" }')}" -python_versions="${override_python_versions:-cp39-cp39 cp310-cp310 cp311-cp311}" +python_versions="${override_python_versions:-cp311-cp311}" output_dir="${output_dir:-${this_dir}/wheelhouse}" packages="${packages:-iree-runtime iree-compiler}" package_suffix="${package_suffix:-}" diff --git a/build_tools/scripts/get_latest_green.sh b/build_tools/scripts/get_latest_green.sh index 979acb2b6ec09..ea08d08125ce2 100755 --- a/build_tools/scripts/get_latest_green.sh +++ b/build_tools/scripts/get_latest_green.sh @@ -36,17 +36,6 @@ function get_latest_green() { local query_string="$(IFS="&" ; echo "${query_params[*]}")" local all_passing="true" - for workflow in "${REQUIRED_WORKFLOWS[@]}"; do - local successful_run_count="$(\ - gh api --jq '.total_count' \ - "/repos/openxla/iree/actions/workflows/${workflow}/runs?${query_string}" \ - )" - # Any successful run of the workflow (including reruns) is OK. 
- if (( successful_run_count==0 )); then - all_passing="false" - break - fi - done if [[ "${all_passing}" == true ]]; then echo "${commit}" return 0 From 9d2ba6531db61107bbfd174aafa9b758923ac89f Mon Sep 17 00:00:00 2001 From: powderluv Date: Sat, 3 Jun 2023 06:45:38 -0700 Subject: [PATCH 24/44] [CI] Add AArch64 builder, disable tests --- .github/workflows/build_package.yml | 80 +++++++++++++--- .../validate_and_publish_release.yml | 95 ------------------- compiler/setup.py | 4 +- runtime/setup.py | 2 +- 4 files changed, 71 insertions(+), 110 deletions(-) diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index 5aaff26541831..894dd4ca98104 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -39,15 +39,11 @@ jobs: matrix: include: # Ubuntu packages. - - runs-on: ubuntu-latest - build-family: linux - build-package: main-dist-linux - experimental: false - - runs-on: ubuntu-latest + - runs-on: icelake build-family: linux build-package: py-compiler-pkg experimental: false - - runs-on: ubuntu-latest + - runs-on: icelake build-family: linux build-package: py-runtime-pkg experimental: false @@ -73,17 +69,34 @@ jobs: build-package: py-runtime-pkg experimental: true + # Linux AArch64 packages. 
+ - runs-on: linux-aarch64 + build-family: linux-aarch64 + build-package: py-compiler-pkg + experimental: false + - runs-on: linux-aarch64 + build-family: linux-aarch64 + build-package: py-runtime-pkg + experimental: false + + env: MANYLINUX_X86_64_IMAGE: gcr.io/iree-oss/manylinux2014_x86_64-release@sha256:e83893d35be4ce3558c989e9d5ccc4ff88d058bc3e74a83181059cc76e2cf1f8 + MANYLINUX_AARCH64_IMAGE: quay.io/pypa/manylinux_2_28_aarch64 steps: + # Docker may leave root owned files + - name: Chown user + if: "matrix.build-family == 'linux-aarch64' || matrix.build-family == 'linux'" + run: | + sudo chown -R $USER:$USER $GITHUB_WORKSPACE - uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0 with: path: "c" # Windows can hit path length limits, so use a short path. submodules: true ref: ${{ github.event.inputs.commit }} - uses: actions/setup-python@v4 - if: "matrix.build-family != 'macos'" + if: "matrix.build-family == 'windows'" with: python-version: '3.11' @@ -158,7 +171,7 @@ jobs: # One step per OS. 
########################################################################## - - name: Build runtime wheels (Linux) + - name: Build runtime wheels (Linux-x86_64) if: "matrix.build-package == 'py-runtime-pkg' && matrix.build-family == 'linux'" shell: bash env: @@ -169,6 +182,16 @@ jobs: [ -e ./bindist/* ] && rm ./bindist/* ./c/build_tools/python_deploy/build_linux_packages.sh + - name: Build runtime wheels (Linux-AArch64) + if: "matrix.build-package == 'py-runtime-pkg' && matrix.build-family == 'linux-aarch64'" + shell: bash + env: + package_suffix: ${{ github.event.inputs.package_suffix }} + packages: "iree-runtime" + output_dir: "${{ github.workspace }}/bindist" + run: | + ./c/build_tools/python_deploy/build_linux_packages.sh + - name: Build runtime wheels (MacOS) if: "matrix.build-package == 'py-runtime-pkg' && matrix.build-family == 'macos'" shell: bash @@ -211,6 +234,16 @@ jobs: [ -e ./bindist/* ] && rm ./bindist/* ./c/build_tools/python_deploy/build_linux_packages.sh + - name: Build compiler wheels (Linux-AArch64) + if: "matrix.build-package == 'py-compiler-pkg' && matrix.build-family == 'linux-aarch64'" + shell: bash + env: + package_suffix: ${{ github.event.inputs.package_suffix }} + packages: "iree-compiler" + output_dir: "${{ github.workspace }}/bindist" + run: | + ./c/build_tools/python_deploy/build_linux_packages.sh + - name: Build compiler wheels (MacOS) if: "matrix.build-package == 'py-compiler-pkg' && matrix.build-family == 'macos'" shell: bash @@ -259,10 +292,10 @@ jobs: path: ./bindist/* retention-days: 5 - # TODO: Upload the tar.bz2 files too when ready - - name: Upload Release Assets - if: github.event.inputs.release_id != '' - id: upload-release-assets + # TODO: One Window Release builds we build both compiler+runtime + - name: Upload Release Assets (Windows) + if: "github.event.inputs.release_id != '' && matrix.build-family == 'windows'" + id: upload-release-assets-windows uses: 
dwenegar/upload-release-assets@5bc3024cf83521df8ebfadf00ad0c4614fd59148 # v1 env: GITHUB_TOKEN: ${{ secrets.WRITE_ACCESS_TOKEN }} @@ -271,6 +304,29 @@ jobs: # Only upload iree artifacts. assets_path: ./bindist/iree*.* + # TODO: Upload the tar.bz2 files too when ready + - name: Upload Release Assets (Compiler) + if: "github.event.inputs.release_id != '' && matrix.build-package == 'py-compiler-pkg' && matrix.build-family != 'windows'" + id: upload-release-assets-compiler + uses: dwenegar/upload-release-assets@5bc3024cf83521df8ebfadf00ad0c4614fd59148 # v1 + env: + GITHUB_TOKEN: ${{ secrets.WRITE_ACCESS_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + # Only upload iree artifacts. + assets_path: ./bindist/iree_compiler*.* + + - name: Upload Release Assets (Runtime) + if: "github.event.inputs.release_id != '' && matrix.build-package == 'py-runtime-pkg' && matrix.build-family != 'windows'" + id: upload-release-assets-runtime + uses: dwenegar/upload-release-assets@5bc3024cf83521df8ebfadf00ad0c4614fd59148 # v1 + env: + GITHUB_TOKEN: ${{ secrets.WRITE_ACCESS_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + # Only upload iree artifacts. + assets_path: ./bindist/iree_runtime*.* + validate_and_publish: name: "Trigger validate and publish release" needs: build_packages diff --git a/.github/workflows/validate_and_publish_release.yml b/.github/workflows/validate_and_publish_release.yml index 34bc503333b75..e86c2665ea6e2 100644 --- a/.github/workflows/validate_and_publish_release.yml +++ b/.github/workflows/validate_and_publish_release.yml @@ -16,89 +16,8 @@ on: required: true jobs: - validate_packages: - name: "Validate packages" - # TODO(jennik): Look into testing windows and macos builds. 
- runs-on: ubuntu-20.04 - steps: - - name: Download packages - id: download_packages - uses: dawidd6/action-download-artifact@5e780fc7bbd0cac69fc73271ed86edf5dcb72d67 # v2.26.0 - with: - github_token: ${{secrets.WRITE_ACCESS_TOKEN}} - workflow: build_package.yml - run_id: ${{ github.event.inputs.build_run_id }} - - name: Extract and display downloaded files - run: | - tar -xf artifact/iree-dist-${{ github.event.inputs.package_version }}-linux-x86_64.tar.xz - pwd - ls -R - - name: Set up python - id: set_up_python - uses: actions/setup-python@d27e3f3d7c64b4bbf8e4abfb9b63b83e846e0435 # v4.5.0 - with: - python-version: "3.9" - - name: Install python packages - id: install_python_packages - run: | - python -m pip install -f file://$PWD/artifact/ iree-compiler iree-runtime iree-tools-tflite iree-tools-tf - # Binaries from the tarball - - name: Run iree-benchmark-module - id: run_iree_benchmark_module - run: ./bin/iree-benchmark-module --help - - name: Run iree-benchmark-trace - id: run_iree_benchmark_trace - run: ./bin/iree-benchmark-trace --help - - name: Run iree-dump-module - id: run_iree_dump_module - run: ./bin/iree-dump-module --help - - name: Run iree-cpuinfo - id: run_iree_cpuinfo - run: ./bin/iree-cpuinfo - - name: Run iree-flatcc-cli - id: run_iree_flatcc_cli - run: ./bin/iree-flatcc-cli --help - - name: Run iree-opt - id: run_iree_opt - run: ./bin/iree-opt --help - - name: Run iree-run-mlir - id: run_iree_run_mlir - run: ./bin/iree-run-mlir --help - - name: Run iree-run-module - id: run_iree_run_module - run: ./bin/iree-run-module --help - - name: Run iree-run-trace - id: run_iree_run_trace - run: ./bin/iree-run-trace --help - - name: Run iree-tblgen - id: run_iree_tblgen - run: ./bin/iree-tblgen --help - - name: Run iree-compile - id: run_iree-compile - run: ./bin/iree-compile --help - # Console scripts from the wheels. 
- - name: Py iree-run-module - id: py_iree-run-module - run: iree-run-module --help - - name: Py iree-run-trace - id: py_iree-run-trace - run: iree-run-trace --help - - name: Py iree-benchmark-module - id: py_iree_benchmark_module - run: iree-benchmark-module --help - - name: Py iree-benchmark-trace - id: py_iree_benchmark_trace - run: iree-benchmark-trace --help - - name: Py iree-dump-module - id: py_iree_dump_module - run: iree-dump-module --help - - name: Py iree-cpuinfo - id: py_iree_cpuinfo - run: iree-cpuinfo - publish_release: name: "Publish release" - needs: validate_packages runs-on: ubuntu-20.04 steps: - name: Publish Release @@ -109,17 +28,3 @@ jobs: with: release_id: ${{ github.event.inputs.release_id }} - - name: Checking out repository - uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0 - with: - token: ${{ secrets.WRITE_ACCESS_TOKEN }} - # Get all history. Otherwise the latest-snapshot branch can't be - # fast-forwarded. - fetch-depth: 0 - - - name: Updating latest-snapshot branch - uses: ad-m/github-push-action@40bf560936a8022e68a3c00e7d2abefaf01305a6 # v0.6.0 - with: - github_token: ${{ secrets.WRITE_ACCESS_TOKEN }} - branch: latest-snapshot - force: true diff --git a/compiler/setup.py b/compiler/setup.py index b49001b94135c..c9ad8468229c7 100644 --- a/compiler/setup.py +++ b/compiler/setup.py @@ -251,11 +251,11 @@ def prepare_installation(): "-GNinja", "--log-level=VERBOSE", "-DIREE_BUILD_PYTHON_BINDINGS=ON", - "-DIREE_BUILD_SAMPLES=OFF", - "-DIREE_BUILD_TESTS=OFF", # Disable .so.0 style symlinking. Python wheels don't preserve links, # so this ~doubles the binary size if not disabled (yikes!). 
"-DCMAKE_PLATFORM_NO_VERSIONED_SONAME=ON", + "-DIREE_BUILD_TESTS=OFF", + "-DIREE_BUILD_SAMPLES=OFF", "-DPython3_EXECUTABLE={}".format(sys.executable), "-DCMAKE_BUILD_TYPE={}".format(cfg), # TODO(scotttodd): include IREE_TARGET_BACKEND_WEBGPU here (and in env) diff --git a/runtime/setup.py b/runtime/setup.py index 238e72cc2f83b..1fd4cb39a2481 100644 --- a/runtime/setup.py +++ b/runtime/setup.py @@ -275,7 +275,7 @@ def build_configuration(cmake_build_dir, cmake_install_dir, extra_cmake_args=()) "OFF" if platform.system() == "Darwin" else "ON", ), get_env_cmake_list("IREE_EXTERNAL_HAL_DRIVERS", - "" if platform.system() != "Linux" else "rocm"), + "" if sysconfig.get_platform() != "linux-x86_64" else "rocm;level_zero"), get_env_cmake_option("IREE_ENABLE_CPUINFO", "ON"), ] + list(extra_cmake_args) add_env_cmake_setting(cmake_args, "IREE_TRACING_PROVIDER") From b4bed9aa959d4e55cf7e9e32bf17bf4bdde848e5 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 10 Aug 2023 16:25:16 -0700 Subject: [PATCH 25/44] Relax NCCL version constraints Instead of requiring exact NCCL version, relax constraints to the standard ABI versioning rules, namely found_version >= major.minor && found_version < major + 1, where major and minor are from the NCCL headers we use. 
--- runtime/src/iree/hal/drivers/cuda/cuda_device.c | 8 ++++---- .../src/iree/hal/drivers/cuda/dynamic_symbols.c | 15 ++++++++------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c index 0f39601610180..2c5a91b6caf4e 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c @@ -365,10 +365,10 @@ static iree_status_t iree_hal_cuda_device_create_channel( if (!device->context_wrapper.syms->nccl_library) { return iree_make_status( IREE_STATUS_UNAVAILABLE, - "NCCL runtime library (%d.%d.%d) not available; ensure installed and " - "the shared library is on your PATH/LD_LIBRARY_PATH " - "(nccl.dll/libnccl.so)", - NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH); + "NCCL runtime library version %d.%d and greater not available; " + " ensure installed and the shared library (nccl.dll/libnccl.so) " + "is on your PATH/LD_LIBRARY_PATH.", + NCCL_MAJOR, NCCL_MINOR); } // Today we only allow a single logical device per channel. 
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c index 436ce82cda5b9..5e3a62240284d 100644 --- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c +++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c @@ -139,11 +139,12 @@ static iree_status_t iree_hal_cuda_nccl_check_version( minor = (nccl_version % 10000) / 100; } patch = nccl_version % 100; - if (major != NCCL_MAJOR || minor != NCCL_MINOR || patch != NCCL_PATCH) { + int required_minimum_version = NCCL_VERSION(NCCL_MAJOR, NCCL_MINOR, 0); + if (major != NCCL_MAJOR || nccl_version < required_minimum_version) { return iree_make_status( IREE_STATUS_UNAVAILABLE, - "NCCL version is %d.%d.%d, but %d.%d.%d is required", major, minor, - patch, NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH); + "NCCL version is %d.%d.%d, but >=%d.%d and <%d is required", major, + minor, patch, NCCL_MAJOR, NCCL_MINOR, NCCL_MAJOR + 1); } return iree_ok_status(); @@ -174,10 +175,10 @@ iree_status_t iree_hal_cuda_nccl_dynamic_symbols_initialize( iree_status_ignore(status); status = iree_make_status( IREE_STATUS_UNAVAILABLE, - "NCCL runtime library (%d.%d.%d) not available; ensure installed and " - "the shared library is on your PATH/LD_LIBRARY_PATH " - "(nccl.dll/libnccl.so)", - NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH); + "NCCL runtime library version %d.%d and greater not available; " + " ensure installed and the shared library (nccl.dll/libnccl.so) " + "is on your PATH/LD_LIBRARY_PATH.", + NCCL_MAJOR, NCCL_MINOR); } if (iree_status_is_ok(status)) { From 9839c878cb0cb586bd7dc48a0288f3ace71a44b6 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Sat, 12 Aug 2023 12:58:04 -0700 Subject: [PATCH 26/44] Remove dependency of iree.runtime to iree.runtime.distributed --- runtime/bindings/python/iree/runtime/__init__.py | 1 - runtime/bindings/python/iree/runtime/distributed/distributed.py | 1 - runtime/bindings/python/iree/runtime/distributed/run_rank.py | 1 - 
.../python/iree/runtime/distributed/sharding_pass_validation.py | 2 +- 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py index b398cdda65aeb..140de1010aa5a 100644 --- a/runtime/bindings/python/iree/runtime/__init__.py +++ b/runtime/bindings/python/iree/runtime/__init__.py @@ -57,5 +57,4 @@ from .function import * from .tracing import * -from . import distributed from . import flags diff --git a/runtime/bindings/python/iree/runtime/distributed/distributed.py b/runtime/bindings/python/iree/runtime/distributed/distributed.py index 258e517b2cf23..31e0a5e13a421 100644 --- a/runtime/bindings/python/iree/runtime/distributed/distributed.py +++ b/runtime/bindings/python/iree/runtime/distributed/distributed.py @@ -4,7 +4,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import iree.compiler import sys import iree.runtime from iree.runtime.array_interop import DeviceArray diff --git a/runtime/bindings/python/iree/runtime/distributed/run_rank.py b/runtime/bindings/python/iree/runtime/distributed/run_rank.py index 7ad00f7256cca..86761d3172b2e 100644 --- a/runtime/bindings/python/iree/runtime/distributed/run_rank.py +++ b/runtime/bindings/python/iree/runtime/distributed/run_rank.py @@ -4,7 +4,6 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import iree.compiler import argparse import iree.runtime from iree.runtime.array_interop import DeviceArray diff --git a/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py b/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py index 4465204516237..599d6604b8a84 100644 --- a/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py +++ b/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py @@ -1,7 +1,7 @@ import iree.compiler import iree.runtime import os -from iree.runtime.distributed import run_ranks +from .distributed import run_ranks import subprocess from pathlib import Path from jax._src.lib import xla_client From 6823d3976e731a8948433efe72aa5c632c40d52f Mon Sep 17 00:00:00 2001 From: powderluv Date: Mon, 14 Aug 2023 22:31:02 -0700 Subject: [PATCH 27/44] Switch windows to self-hosted --- .github/workflows/build_package.yml | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index 894dd4ca98104..7f7d07aadd2cc 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -60,11 +60,11 @@ jobs: # Windows packages. 
- runs-on: - - ${{ github.repository == 'openxla/iree' && 'windows-2022-64core' || 'windows-2022'}} + - ${{ github.repository == 'openxla/iree' && 'windows-2022-64core' || '7950X'}} build-family: windows build-package: py-compiler-pkg experimental: true - - runs-on: windows-2022 + - runs-on: 7950X build-family: windows build-package: py-runtime-pkg experimental: true @@ -104,13 +104,15 @@ jobs: # OS specific setup ########################################################################## - - name: Install dependencies (Windows) - if: "matrix.build-family == 'windows'" - shell: powershell - run: ./c/build_tools/python_deploy/install_windows_deps.ps1 + #- name: Install dependencies (Windows) + # if: "matrix.build-family == 'windows'" + # shell: powershell + # run: ./c/build_tools/python_deploy/install_windows_deps.ps1 - name: "Configure MSVC (Windows)" if: "matrix.build-family == 'windows'" uses: ilammy/msvc-dev-cmd@7315a94840631165970262a99c72cfb48a65d25d # v1.12.0 + with: + arch: x64 ########################################################################## # Write version_info.json From e3de3b981ef680fc3d51b8eae1be81baa4f55137 Mon Sep 17 00:00:00 2001 From: powderluv Date: Tue, 15 Aug 2023 15:02:16 -0700 Subject: [PATCH 28/44] clean up bindist before building --- .github/workflows/build_package.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index 7f7d07aadd2cc..8774c93abbebc 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -192,6 +192,7 @@ jobs: packages: "iree-runtime" output_dir: "${{ github.workspace }}/bindist" run: | + [ -e ./bindist/* ] && rm ./bindist/* ./c/build_tools/python_deploy/build_linux_packages.sh - name: Build runtime wheels (MacOS) From cdcb47d7fa161a40ef6963807a76d0dfa791e9c4 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Wed, 16 Aug 2023 18:07:36 -0700 Subject: [PATCH 29/44] [LevelZero] remove initial data
argument from buffer allocation Makes the driver compliant with the HAL API change. --- experimental/level_zero/level_zero_allocator.c | 14 +------------- experimental/level_zero/level_zero_device.c | 6 +++--- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/experimental/level_zero/level_zero_allocator.c b/experimental/level_zero/level_zero_allocator.c index 04d64119b51c3..c18d7eed5c9b3 100644 --- a/experimental/level_zero/level_zero_allocator.c +++ b/experimental/level_zero/level_zero_allocator.c @@ -192,7 +192,7 @@ static void iree_hal_level_zero_buffer_free( static iree_status_t iree_hal_level_zero_allocator_allocate_buffer( iree_hal_allocator_t* IREE_RESTRICT base_allocator, const iree_hal_buffer_params_t* IREE_RESTRICT params, - iree_device_size_t allocation_size, iree_const_byte_span_t initial_data, + iree_device_size_t allocation_size, iree_hal_buffer_t** IREE_RESTRICT out_buffer) { iree_hal_level_zero_allocator_t* allocator = iree_hal_level_zero_allocator_cast(base_allocator); @@ -267,18 +267,6 @@ static iree_status_t iree_hal_level_zero_allocator_allocate_buffer( /*byte_length=*/allocation_size, device_ptr, host_ptr, &buffer); } - // Copy the initial contents into the buffer. This may require staging.
- if (iree_status_is_ok(status) && - !iree_const_byte_span_is_empty(initial_data)) { - status = iree_hal_device_transfer_range( - allocator->base_device, - iree_hal_make_host_transfer_buffer_span((void*)initial_data.data, - initial_data.data_length), - 0, iree_hal_make_device_transfer_buffer(buffer), 0, - initial_data.data_length, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, - iree_infinite_timeout()); - } - if (iree_status_is_ok(status)) { IREE_STATISTICS(iree_hal_allocator_statistics_record_alloc( &allocator->statistics, compat_params.type, allocation_size)); diff --git a/experimental/level_zero/level_zero_device.c b/experimental/level_zero/level_zero_device.c index d27abd164e69a..ff7ebf1b01010 100644 --- a/experimental/level_zero/level_zero_device.c +++ b/experimental/level_zero/level_zero_device.c @@ -340,9 +340,9 @@ static iree_status_t iree_hal_level_zero_device_queue_alloca( // TODO(benvanik): queue-ordered allocations. IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list, iree_infinite_timeout())); - IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer( - iree_hal_device_allocator(base_device), params, allocation_size, - iree_const_byte_span_empty(), out_buffer)); + IREE_RETURN_IF_ERROR( + iree_hal_allocator_allocate_buffer(iree_hal_device_allocator(base_device), + params, allocation_size, out_buffer)); IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_signal(signal_semaphore_list)); return iree_ok_status(); } From 54aa9faad303c35a9deb42015f9afe254e9cc091 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Wed, 23 Aug 2023 20:20:10 -0400 Subject: [PATCH 30/44] [SPIRV] Reduce the number of warps used by subgroup reduce Currently the number of subgroups to use is driven by a target vector size, which for large reductions can end up translating to a large number of subgroups. This adds a preferred unrolling factor on the vector size to reduce the default number of subgroups. 
--- compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp | 4 +++- .../compiler/Codegen/SPIRV/test/config_default_matvec.mlir | 2 +- .../compiler/Codegen/SPIRV/test/config_default_reduction.mlir | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index 1579eabf27dc3..ec3bb090ec50f 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -40,6 +40,7 @@ using llvm::APIntOps::GreatestCommonDivisor; constexpr unsigned numTilesPerSubgroupDimK = 2; constexpr int kMaxVectorNumBits = 128; +constexpr int kPreferredReductionVectorUnrollAmount = 8; namespace mlir { namespace iree_compiler { @@ -1259,7 +1260,8 @@ static LogicalResult setReductionConfig(const spirv::TargetEnv &targetEnv, return failure(); // Let each thread handle `vectorSize` elements. - unsigned vectorSize = kMaxVectorNumBits / bitWidth; + unsigned vectorSize = + kPreferredReductionVectorUnrollAmount * kMaxVectorNumBits / bitWidth; while ((dimSize / vectorSize) % subgroupSize != 0) vectorSize /= 2; diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir index f60f5a8d4bdbc..4cbe804a85224 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir @@ -140,7 +140,7 @@ hal.executable @i4_dequant_matvec_f32 { // CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK-LABEL: hal.executable.export public @i4_dequant_matvec_f32 // CHECK-SAME: translation_info = #[[$TRANSLATION]] -// CHECK-SAME: workgroup_size = [1024 : index, 1 : index, 1 : index] +// CHECK-SAME: workgroup_size = [128 : index, 1 : index, 1 : index] // CHECK: func.func @i4_dequant_matvec_f32() // CHECK: 
linalg.generic // CHECK-SAME: lowering_config = #[[$CONFIG]] diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir index 7b08b772ddf1f..896b735aa02bb 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir @@ -47,7 +47,7 @@ hal.executable private @subgroup_reduce_f32 { // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @subgroup_reduce_f32 // CHECK-SAME: translation_info = #[[TRANSLATION]] -// CHECK-SAME: workgroup_size = [128 : index, 1 : index, 1 : index] +// CHECK-SAME: workgroup_size = [16 : index, 1 : index, 1 : index] // CHECK: func.func @subgroup_reduce_f32() // CHECK: linalg.generic // CHECK-SAME: lowering_config = #[[CONFIG]] @@ -105,7 +105,7 @@ hal.executable private @subgroup_reduce_f16 { // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @subgroup_reduce_f16 // CHECK-SAME: translation_info = #[[TRANSLATION]] -// CHECK-SAME: workgroup_size = [512 : index, 1 : index, 1 : index] +// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index] // CHECK: func.func @subgroup_reduce_f16() // CHECK: linalg.generic // CHECK-SAME: lowering_config = #[[CONFIG]] From f79be87f02943c9498055a1319650da09bd0b26b Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 25 Aug 2023 03:33:22 -0400 Subject: [PATCH 31/44] Revert "[codegen][spirv] Pack/transpose matrix B for better coop mmma" This reverts commit a6512dc5ba99912eaa24cd0b8a5892db1b38e053. 
--- .../Codegen/SPIRV/SPIRVTileAndDistribute.cpp | 4 +- .../Flow/Transforms/FormDispatchRegions.cpp | 9 +- .../Dialect/Flow/Transforms/Passes.cpp | 5 - .../compiler/Preprocessing/Common/BUILD.bazel | 3 +- .../Preprocessing/Common/CMakeLists.txt | 1 - .../Common/ConvertLinalgMatmulToMmt.cpp | 119 -------------- .../Common/GeneralizeAndFuse.cpp | 148 ------------------ .../compiler/Preprocessing/Common/Passes.h | 6 - .../compiler/Preprocessing/Common/Passes.td | 12 -- 9 files changed, 4 insertions(+), 303 deletions(-) delete mode 100644 compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp delete mode 100644 compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp index 7df6629b0b4f9..260f8bb7857ff 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp @@ -98,8 +98,8 @@ static void populateTilingReductionPatterns( .setLoopType(linalg::LinalgTilingLoopType::Loops) .setTileSizeComputationFunction(computeFn); - TilingPatterns::insert(patterns, tilingOptions, filter); + TilingPatterns::insert( + patterns, tilingOptions, filter); filter.addFilter([](Operation *op) { return success(isa(op)); }); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp index fbc66e9175d29..ea00c12bb53a5 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp @@ -653,14 +653,7 @@ isFusableWithProducer(OpOperand &operand, } auto consumerLinalgOp = cast(consumer); - if (consumerLinalgOp.isDpsInput(&operand)) { - // TODO: Add some marker on transpose and MatmulOp to indicate mmt. 
- bool fuseTransposeAndMatmul = - isa(consumer) && isa(producer); - if (fuseTransposeAndMatmul) { - return true; - } - } else if (!consumerLinalgOp.isDpsInit(&operand)) { + if (!consumerLinalgOp.isDpsInit(&operand)) { return false; } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index b8394b36f7b08..6601d5b815d3c 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp @@ -80,11 +80,6 @@ static llvm::cl::opt clDispatchGenerateWorkloadRegion( llvm::cl::desc("Generate the workload region."), llvm::cl::init(true)); -static llvm::cl::opt clEnableTransposeMatmulLayout( - "iree-flow-enable-transpose-matmul-layout", - llvm::cl::desc("Enable transposing the B matrix for matmuls."), - llvm::cl::init(false)); - static llvm::cl::opt clNormalizeInputIndexingMap( "iree-flow-normalize-input-indexing-map", llvm::cl::desc("Enable normalizing input indexing map to identity."), diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index 08795e578b27b..7c43380e1f662 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -34,10 +34,9 @@ iree_compiler_cc_library( "ConvertConvNchwToNhwc.cpp", "ConvertConvToChannelsLast.cpp", "ConvertLinalgMatmulToMmt.cpp", - "GeneralizeAndFuse.cpp", "GeneralizeConvolutions.cpp", "MakeSingleDispatchForFunction.cpp", - "PadLinalgOps.cpp", + "PadLinalgOps.cpp", "PassDetail.h", "Passes.cpp", ], diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index 1b83bc1ea1d74..a147974d907dd 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ 
-30,7 +30,6 @@ iree_cc_library( "ConvertConvNchwToNhwc.cpp" "ConvertConvToChannelsLast.cpp" "ConvertLinalgMatmulToMmt.cpp" - "GeneralizeAndFuse.cpp" "GeneralizeConvolutions.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp deleted file mode 100644 index f22e55cf92ac0..0000000000000 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertLinalgMatmulToMmt.cpp +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include -#include - -#include "iree/compiler/Preprocessing/Common/PassDetail.h" -#include "iree/compiler/Preprocessing/Common/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { - -namespace { - -// Converts linalg.matmul to an linalg.transpose + linalg.matmul. -// Such that matrix B layout changes to col major. 
-class LinalgMatmulOpToLinalgMmtPattern final - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, - PatternRewriter &rewriter) const override { - Location loc = matmulOp.getLoc(); - Value lhs = matmulOp.getDpsInputOperand(0)->get(); - Value rhs = matmulOp.getDpsInputOperand(1)->get(); - Value acc = matmulOp.getDpsInitOperand(0)->get(); - if (dyn_cast(rhs.getDefiningOp())) { - return failure(); - } - auto rhsType = rhs.getType().cast(); - auto rhsShape = rhsType.getShape(); - auto rhsElemType = rhsType.getElementType(); - SmallVector transposedRhsShape = {rhsShape[1], rhsShape[0]}; - - // GenericOp - int64_t nloops = rhsShape.size(); - AffineExpr mDim, nDim; - bindDims(getContext(), mDim, nDim); - auto inputMap = AffineMap::get(2, 0, {mDim, nDim}, getContext()); - auto packedMap = AffineMap::get(2, 0, {nDim, mDim}, getContext()); - SmallVector indexingMaps = {inputMap, packedMap}; - - Value transposedRhs = - rewriter.create(loc, transposedRhsShape, rhsElemType); - SmallVector loopAttributeTypes( - nloops, utils::IteratorType::parallel); - - Value packedRhs = - rewriter - .create( - loc, transposedRhs.getType(), - /*inputs=*/rhs, /*outputs=*/transposedRhs, indexingMaps, - loopAttributeTypes, - [&](OpBuilder &nestedBuilder, Location nestedLoc, - ValueRange args) { - nestedBuilder.create(nestedLoc, args[0]); - }) - .getResult(0); - - // TransposeOp - Value initOp = rewriter.create(loc, rhsShape, rhsElemType); - SmallVector transposedPerm = {1, 0}; - Value transposePackedRhs = - rewriter - .create(loc, packedRhs, initOp, transposedPerm) - .getResults()[0]; - - // MatmulOp - Value packedMatmul = - rewriter - .create(loc, matmulOp.getResult(0).getType(), - ArrayRef{lhs, transposePackedRhs}, - ArrayRef{acc}) - .getResult(0); - rewriter.replaceOp(matmulOp, packedMatmul); - return success(); - } -}; - -struct ConvertLinalgMatmulToMmtPass - : public ConvertLinalgMatmulToMmtBase { - 
void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext *context = &getContext(); - // Main pattern. - { - RewritePatternSet patterns(&getContext()); - patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } - } -}; -} // namespace - -std::unique_ptr createConvertLinalgMatmulToMmtPass() { - return std::make_unique(); -} - -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp deleted file mode 100644 index 8a4e7bee31e3d..0000000000000 --- a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeAndFuse.cpp +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include -#include - -#include "iree/compiler/Preprocessing/Common/PassDetail.h" -#include "iree/compiler/Preprocessing/Common/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { - -namespace { - -//===----------------------------------------------------------------------===// -// Utility Functions -//===----------------------------------------------------------------------===// - -bool isMatmulOrBatchMatmul(linalg::LinalgOp linalgOp) { - return linalg::isaContractionOpInterface(linalgOp) && - llvm::is_contained({2u, 3u}, linalgOp.getNumParallelLoops()); -} - -//===----------------------------------------------------------------------===// -// Generalize and fusion patterns. -//===----------------------------------------------------------------------===// - -struct GeneralizeAndFusePass - : public GeneralizeAndFuseBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - template - class GeneralizeTargetNamedOpPattern final - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(LinalgOpType linalgOp, - PatternRewriter &rewriter) const override { - // TODO: Check consumer is transposeOp. 
- // TODO: Generalize transpos - FailureOr genericOp = - linalg::generalizeNamedOp(rewriter, linalgOp); - if (failed(genericOp)) return failure(); - return success(); - } - }; - - class FuseMatmulAndTranspose final - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - // Inspo: - // https://github.com/llvm/llvm-project/blob/4f1c12425179608298dc39f5524ba2612609b5e4/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp - LogicalResult matchAndRewrite(linalg::GenericOp linalgOp, - PatternRewriter &rewriter) const override { - const unsigned rhsId = 1; - if (!isMatmulOrBatchMatmul(linalgOp)) return failure(); - Value rhs = linalgOp.getDpsInputOperand(rhsId)->get(); - auto transposeOp = dyn_cast(rhs.getDefiningOp()); - if (!transposeOp) return failure(); - auto perm = transposeOp.getPermutation(); - auto indexingMaps = linalgOp.getIndexingMaps(); - auto rhsMap = indexingMaps[rhsId].cast().getValue(); - int64_t rank = perm.size(); - if (rhsMap.getNumResults() != rank) return failure(); - SmallVector exprs; - for (auto dim_id : perm) { - exprs.push_back(rhsMap.getResult(dim_id)); - } - AffineMap transposedRhsMap = - AffineMap::get(rhsMap.getNumDims(), 0, exprs, getContext()); - - // TODO: Fold transposeOp as transposed indexing for matmulOp. - // Generate a map set. - auto lhsMap = indexingMaps[0].cast().getValue(); - auto accMap = indexingMaps[2].cast().getValue(); - SmallVector newIndexingMaps = {lhsMap, transposedRhsMap, - accMap}; - - // Generate new list of args. - Value newRhs = transposeOp.getDpsInputOperand(0)->get(); - Value lhs = linalgOp.getDpsInputOperand(0)->get(); - Value acc = linalgOp.getDpsInitOperand(0)->get(); - SmallVector inputs = {lhs, newRhs}; - - // Generate a new genericOp. 
- linalg::GenericOp genericOp = rewriter.create( - linalgOp.getLoc(), linalgOp.getResultTypes(), /*inputs*/ inputs, - /*outputs*/ acc, newIndexingMaps, linalgOp.getIteratorTypesArray()); - // Block consumerBlock = linalgOp->getRegion(0).front(); - // genericOp.getRegion().push_back(consumerBlock); - // llvm::outs()<<"new op - // regions:"<getNumRegions()<<"\n"; - // llvm::outs()<<"new op - // regions:"<getNumRegions()<<"\n"; - // llvm::outs()<<"new op - // blocks:"<getNumRegions()<<"\n"; - // llvm::outs()<<"old op - // blocks:"<getRegion(0), genericOp.getRegion(), - genericOp.getRegion().begin()); - rewriter.replaceOp(linalgOp, genericOp->getResults()); - return success(); - } - }; - - void runOnOperation() override { - MLIRContext *context = &getContext(); - // Main pattern. - // Generalize + Fuse pattern. - { - RewritePatternSet patterns(&getContext()); - patterns.insert, - FuseMatmulAndTranspose>(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } - } -}; -} // namespace - -std::unique_ptr createGeneralizeAndFusePass() { - return std::make_unique(); -} - -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h index a339fb37d5a2c..0680d493965e5 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h @@ -26,12 +26,6 @@ std::unique_ptr createConvertConv2DToImg2ColPass(); std::unique_ptr> createConvertConvNchwToNhwcPass(); -// Pass to convert a linalg.matmul into linalg.transpose + linalg.matmul. -std::unique_ptr createConvertLinalgMatmulToMmtPass(); - -// Generalizes named op and try to fuse them -std::unique_ptr createGeneralizeAndFusePass(); - /// Moves the body of the entire function into a single dispatch. 
std::unique_ptr> createMakeSingleDispatchForFunctionPass(); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index 2948aaaab929c..1fe414d990970 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -22,18 +22,6 @@ def ConvertConvNchwToNhwc : "mlir::iree_compiler::IREE::createConvertConvNchwToNhwcPass()"; } -def ConvertLinalgMatmulToMmt : - Pass<"iree-flow-convert-linalg-matmul-to-mmt", ""> { - let summary = "Convert linalg.matmul to linalg.transpose + linalg.matmul"; - let constructor = "mlir::iree_compiler::IREE::createConvertLinalgMatmulToMmtPass()"; -} - -def GeneralizeAndFuse : - Pass<"iree-flow-generalize-and-fuse", ""> { - let summary = "Generalizes named op and try to fuse them."; - let constructor = "mlir::iree_compiler::IREE::createGeneralizeAndFusePass()"; -} - def MakeSingleDispatchForFunction : Pass<"iree-preprocessing-make-single-dispatch-for-function", "func::FuncOp"> { let summary = "Convert entire function into a single dispatch"; From 1b45caf652d5f2c50d82e952e38503cdf1663e3a Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Fri, 25 Aug 2023 13:45:20 +0000 Subject: [PATCH 32/44] Remove ConvertLinalgMatmulToMmt completely --- compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel | 1 - compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt | 1 - 2 files changed, 2 deletions(-) diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index 7c43380e1f662..f33d8712093bb 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -33,7 +33,6 @@ iree_compiler_cc_library( "ConvertConv2DToImg2Col.cpp", "ConvertConvNchwToNhwc.cpp", "ConvertConvToChannelsLast.cpp", - "ConvertLinalgMatmulToMmt.cpp", 
"GeneralizeConvolutions.cpp", "MakeSingleDispatchForFunction.cpp", "PadLinalgOps.cpp", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index a147974d907dd..09e61373aed9d 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -29,7 +29,6 @@ iree_cc_library( "ConvertConv2DToImg2Col.cpp" "ConvertConvNchwToNhwc.cpp" "ConvertConvToChannelsLast.cpp" - "ConvertLinalgMatmulToMmt.cpp" "GeneralizeConvolutions.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" From b7b23cecd9cef25e1f7536752133dc976e8eeb40 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 24 Aug 2023 01:11:37 -0500 Subject: [PATCH 33/44] Add hip headers to build ROCm backend without the SDK. --- .../python_deploy/build_linux_packages.sh | 2 + .../python_deploy/build_windows_packages.ps1 | 2 + experimental/rocm/CMakeLists.txt | 4 +- experimental/rocm/hip_headers.h | 1489 +++++++++++++++++ experimental/rocm/rocm_headers.h | 2 +- 5 files changed, 1496 insertions(+), 3 deletions(-) create mode 100644 experimental/rocm/hip_headers.h diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 3363d35b421cb..c2739fc905854 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -157,10 +157,12 @@ function build_iree_runtime() { export IREE_RUNTIME_BUILD_TRACY=ON # We install the needed build deps below for the tools. 
export IREE_RUNTIME_BUILD_TRACY_TOOLS=ON + export IREE_EXTERNAL_HAL_DRIVERS="rocm" build_wheel runtime/ } function build_iree_compiler() { + export IREE_TARGET_BACKEND_ROCM=ON build_wheel compiler/ } diff --git a/build_tools/python_deploy/build_windows_packages.ps1 b/build_tools/python_deploy/build_windows_packages.ps1 index 43906f8800f71..a8fbd3a9aadf9 100644 --- a/build_tools/python_deploy/build_windows_packages.ps1 +++ b/build_tools/python_deploy/build_windows_packages.ps1 @@ -67,11 +67,13 @@ function run() { function build_iree_runtime() { param($python_version) $env:IREE_HAL_DRIVER_VULKAN = "ON" + $env:IREE_EXTERNAL_HAL_DRIVERS = "rocm" & py -${python_version} -m pip wheel -v -w $output_dir $repo_root/runtime/ } function build_iree_compiler() { param($python_version) + $env:IREE_TARGET_BACKEND_ROCM= "ON" py -${python_version} -m pip wheel -v -w $output_dir $repo_root/compiler/ } diff --git a/experimental/rocm/CMakeLists.txt b/experimental/rocm/CMakeLists.txt index c2b313ddd0e30..8b3c3b9eb2e2a 100644 --- a/experimental/rocm/CMakeLists.txt +++ b/experimental/rocm/CMakeLists.txt @@ -25,7 +25,7 @@ set(IREE_ROCM_BC_DIR "${IREE_ROCM_BC_DIR_DEFAULT}" CACHE STRING iree_add_all_subdirs() if(NOT ROCM_HEADERS_API_ROOT) - set(ROCM_HEADERS_API_ROOT "/opt/rocm/include") + set(ROCM_HEADERS_API_ROOT ${CMAKE_CURRENT_LIST_DIR}) endif() if(EXISTS ${ROCM_HEADERS_API_ROOT}) @@ -68,7 +68,7 @@ iree_cc_library( INCLUDES "${CMAKE_CURRENT_LIST_DIR}/../.." "${PROJECT_BINARY_DIR}" - "${ROCM_HEADERS_API_ROOT}" + "${ROCM_HEADERS_API_ROOT}/.." DEPS ::dynamic_symbols iree::base diff --git a/experimental/rocm/hip_headers.h b/experimental/rocm/hip_headers.h new file mode 100644 index 0000000000000..28bcd5667482b --- /dev/null +++ b/experimental/rocm/hip_headers.h @@ -0,0 +1,1489 @@ +/* + * Copyright 2011-2021 Blender Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +#ifndef __HIPEW_H__ +#define __HIPEW_H__ + +#ifdef __cplusplus +extern "C" { +#endif + +#include + +#define HIP_IPC_HANDLE_SIZE 64 +#define hipHostMallocDefault 0x00 +#define hipHostMallocPortable 0x01 +#define hipHostMallocMapped 0x02 +#define hipHostMallocWriteCombined 0x04 +#define hipHostMallocNumaUser 0x20000000 +#define hipHostMallocCoherent 0x40000000 +#define hipHostMallocNonCoherent 0x80000000 +#define hipHostRegisterPortable 0x01 +#define hipHostRegisterMapped 0x02 +#define hipHostRegisterIoMemory 0x04 +#define hipCooperativeLaunchMultiDeviceNoPreSync 0x01 +#define hipCooperativeLaunchMultiDeviceNoPostSync 0x02 +#define hipArrayLayered 0x01 +#define hipArraySurfaceLoadStore 0x02 +#define hipArrayCubemap 0x04 +#define hipArrayTextureGather 0x08 +#define HIP_TRSA_OVERRIDE_FORMAT 0x01 +#define HIP_TRSF_READ_AS_INTEGER 0x01 +#define HIP_TRSF_NORMALIZED_COORDINATES 0x02 +#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01) +#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02) +#define HIP_LAUNCH_PARAM_END ((void*)0x03) + +/* Functions which changed 3.1 -> 3.2 for 64 bit stuff, + * the cuda library has both the old ones for compatibility and new + * ones with _v2 postfix, + */ +#define hipModuleGetGlobal hipModuleGetGlobal +#define hipMemGetInfo hipMemGetInfo +#define hipMemAllocPitch hipMemAllocPitch +#define hipMemGetAddressRange hipMemGetAddressRange +#define hipMemcpy hipMemcpy +#define hipMemcpyHtoD hipMemcpyHtoD +#define hipMemcpyDtoH hipMemcpyDtoH +#define hipMemcpyDtoD hipMemcpyDtoD +#define hipMemcpyHtoA hipMemcpyHtoA +#define 
hipMemcpyAtoH hipMemcpyAtoH +/* +* #define hipMemcpyAsync hipMemcpyAsync +* #define hipMemcpyHtoDAsync hipMemcpyHtoDAsync +* #define hipMemcpyDtoHAsync hipMemcpyDtoHAsync +* #define hipMemcpyDtoDAsync hipMemcpyDtoDAsync +*/ +#define hipMemsetD8 hipMemsetD8 +#define hipMemsetD16 hipMemsetD16 +#define hipMemsetD32 hipMemsetD32 +#define hipMemsetAsync hipMemsetAsync +#define hipMemsetD8Async hipMemsetD8Async +#define hipMemsetD16Async hipMemsetD16Async +#define hipMemsetD32Async hipMemsetD32Async +#define hipArrayCreate hipArrayCreate +#define hipArray3DCreate hipArray3DCreate +#define hipPointerGetAttributes hipPointerGetAttributes +#define hipTexRefSetAddress hipTexRefSetAddress +#define hipTexRefGetAddress hipTexRefGetAddress +#define hipStreamDestroy hipStreamDestroy +#define hipEventDestroy hipEventDestroy +#define hipTexRefSetAddress2D hipTexRefSetAddress2D + +/* Types. */ +#ifdef _MSC_VER +typedef unsigned __int32 hipuint32_t; +typedef unsigned __int64 hipuint64_t; +#else +#include +typedef uint32_t hipuint32_t; +typedef uint64_t hipuint64_t; +#endif + +#if defined(__x86_64) || defined(AMD64) || defined(_M_AMD64) || defined (__aarch64__) +/* + * Changed to void* for Windows / MSVC + * typedef unsigned long long hipDeviceptr_t; + */ +typedef void * hipDeviceptr_t; +#else +typedef unsigned int hipDeviceptr_t; +#endif + + +#ifdef _WIN32 +# define HIPAPI __stdcall +# define HIP_CB __stdcall +#else +# define HIPAPI +# define HIP_CB +#endif + +typedef int hipDevice_t; +typedef struct ihipCtx_t* hipCtx_t; +typedef struct ihipModule_t* hipModule_t; +typedef struct ihipModuleSymbol_t* hipFunction_t; +typedef struct hipArray* hArray; +typedef struct hipMipmappedArray_st* hipMipmappedArray_t; +typedef struct ihipEvent_t* hipEvent_t; +typedef struct ihipStream_t* hipStream_t; +typedef unsigned long long hipTextureObject_t; +typedef void* hipExternalMemory_t; + +typedef struct HIPuuid_st { + char bytes[16]; +} HIPuuid; + +typedef enum hipMemcpyKind { + hipMemcpyHostToHost = 
0, + hipMemcpyHostToDevice = 1, + hipMemcpyDeviceToHost = 2, + hipMemcpyDeviceToDevice = 3, + hipMemcpyDefault = 4 +} hipMemcpyKind; + +typedef enum hipChannelFormatKind { + hipChannelFormatKindSigned = 0, + hipChannelFormatKindUnsigned = 1, + hipChannelFormatKindFloat = 2, + hipChannelFormatKindNone = 3, +}hipChannelFormatKind; + +typedef struct hipChannelFormatDesc { + int x; + int y; + int z; + int w; + enum hipChannelFormatKind f; +}hipChannelFormatDesc; + +typedef enum hipTextureFilterMode { + hipFilterModePoint = 0, + hipFilterModeLinear = 1, +} hipTextureFilterMode; + +typedef enum hipArray_Format { + HIP_AD_FORMAT_UNSIGNED_INT8 = 0x01, + HIP_AD_FORMAT_SIGNED_INT8 = 0x08, + HIP_AD_FORMAT_UNSIGNED_INT16 = 0x02, + HIP_AD_FORMAT_SIGNED_INT16 = 0x09, + HIP_AD_FORMAT_UNSIGNED_INT32 = 0x03, + HIP_AD_FORMAT_SIGNED_INT32 = 0x0a, + HIP_AD_FORMAT_HALF = 0x10, + HIP_AD_FORMAT_FLOAT = 0x20, +} hipArray_Format; + +typedef enum hipTextureAddressMode { + hipAddressModeWrap = 0, + hipAddressModeClamp = 1, + hipAddressModeMirror = 2, + hipAddressModeBorder = 3, +} hipTextureAddressMode; + +/** + * hip texture reference + */ +typedef struct textureReference { + int normalized; + //enum hipTextureReadMode readMode;// used only for driver API's + enum hipTextureFilterMode filterMode; + enum hipTextureAddressMode addressMode[3]; // Texture address mode for up to 3 dimensions + struct hipChannelFormatDesc channelDesc; + int sRGB; // Perform sRGB->linear conversion during texture read + unsigned int maxAnisotropy; // Limit to the anisotropy ratio + enum hipTextureFilterMode mipmapFilterMode; + float mipmapLevelBias; + float minMipmapLevelClamp; + float maxMipmapLevelClamp; + + hipTextureObject_t textureObject; + int numChannels; + enum hipArray_Format format; +}textureReference; + +typedef textureReference* hipTexRef; + +typedef enum hipMemoryType { + hipMemoryTypeHost = 0x00, + hipMemoryTypeDevice = 0x01, + hipMemoryTypeArray = 0x02, + hipMemoryTypeUnified = 0x03, +} 
hipMemoryType; + +/** + * Pointer attributes + */ +typedef struct hipPointerAttribute_t { + enum hipMemoryType memoryType; + int device; + void* devicePointer; + void* hostPointer; + int isManaged; + unsigned allocationFlags; /* flags specified when memory was allocated*/ + /* peers? */ +} hipPointerAttribute_t; + +typedef struct ihipIpcEventHandle_t { + char reserved[HIP_IPC_HANDLE_SIZE]; +} ihipIpcEventHandle_t; + +typedef struct hipIpcMemHandle_st { + char reserved[HIP_IPC_HANDLE_SIZE]; +} hipIpcMemHandle_t; + +typedef enum HIPipcMem_flags_enum { + hipIpcMemLazyEnablePeerAccess = 0x1, +} HIPipcMem_flags; + +typedef enum HIPmemAttach_flags_enum { + hipMemAttachGlobal = 0x1, + hipMemAttachHost = 0x2, + HIP_MEM_ATTACH_SINGLE = 0x4, +} HIPmemAttach_flags; + +typedef enum HIPctx_flags_enum { + hipDeviceScheduleAuto = 0x00, + hipDeviceScheduleSpin = 0x01, + hipDeviceScheduleYield = 0x02, + hipDeviceScheduleBlockingSync = 0x04, + hipDeviceScheduleMask = 0x07, + hipDeviceMapHost = 0x08, + hipDeviceLmemResizeToMax = 0x10, +} HIPctx_flags; + +typedef enum HIPstream_flags_enum { + hipStreamDefault = 0x0, + hipStreamNonBlocking = 0x1, +} HIPstream_flags; + +typedef enum HIPevent_flags_enum { + hipEventDefault = 0x0, + hipEventBlockingSync = 0x1, + hipEventDisableTiming = 0x2, + hipEventInterprocess = 0x4, +} HIPevent_flags; + +typedef enum HIPstreamWaitValue_flags_enum { + HIP_STREAM_WAIT_VALUE_GEQ = 0x0, + HIP_STREAM_WAIT_VALUE_EQ = 0x1, + HIP_STREAM_WAIT_VALUE_AND = 0x2, + HIP_STREAM_WAIT_VALUE_NOR = 0x3, + HIP_STREAM_WAIT_VALUE_FLUSH = (1 << 30), +} HIPstreamWaitValue_flags; + +typedef enum HIPstreamWriteValue_flags_enum { + HIP_STREAM_WRITE_VALUE_DEFAULT = 0x0, + HIP_STREAM_WRITE_VALUE_NO_MEMORY_BARRIER = 0x1, +} HIPstreamWriteValue_flags; + +typedef enum HIPstreamBatchMemOpType_enum { + HIP_STREAM_MEM_OP_WAIT_VALUE_32 = 1, + HIP_STREAM_MEM_OP_WRITE_VALUE_32 = 2, + HIP_STREAM_MEM_OP_WAIT_VALUE_64 = 4, + HIP_STREAM_MEM_OP_WRITE_VALUE_64 = 5, + 
HIP_STREAM_MEM_OP_FLUSH_REMOTE_WRITES = 3, +} HIPstreamBatchMemOpType; + + +typedef union HIPstreamBatchMemOpParams_union { + HIPstreamBatchMemOpType operation; + struct HIPstreamMemOpWaitValueParams_st { + HIPstreamBatchMemOpType operation; + hipDeviceptr_t address; + union { + hipuint32_t value; + hipuint64_t value64; + }; + unsigned int flags; + hipDeviceptr_t alias; + } waitValue; + struct HIPstreamMemOpWriteValueParams_st { + HIPstreamBatchMemOpType operation; + hipDeviceptr_t address; + union { + hipuint32_t value; + hipuint64_t value64; + }; + unsigned int flags; + hipDeviceptr_t alias; + } writeValue; + struct HIPstreamMemOpFlushRemoteWritesParams_st { + HIPstreamBatchMemOpType operation; + unsigned int flags; + } flushRemoteWrites; + hipuint64_t pad[6]; +} HIPstreamBatchMemOpParams; + +typedef enum HIPoccupancy_flags_enum { + hipOccupancyDefault = 0x0, + HIP_OCCUPANCY_DISABLE_CACHING_OVERRIDE = 0x1, +} HIPoccupancy_flags; + +typedef enum hipDeviceAttribute_t { + hipDeviceAttributeCudaCompatibleBegin = 0, + hipDeviceAttributeEccEnabled = hipDeviceAttributeCudaCompatibleBegin, ///< Whether ECC support is enabled. + hipDeviceAttributeAccessPolicyMaxWindowSize, ///< Cuda only. The maximum size of the window policy in bytes. + hipDeviceAttributeAsyncEngineCount, ///< Cuda only. Asynchronous engines number. + hipDeviceAttributeCanMapHostMemory, ///< Whether host memory can be mapped into device address space + hipDeviceAttributeCanUseHostPointerForRegisteredMem,///< Cuda only. Device can access host registered memory + ///< at the same virtual address as the CPU + hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz. + hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in. + hipDeviceAttributeComputePreemptionSupported, ///< Cuda only. Device supports Compute Preemption. + hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels concurrently. 
+ hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory concurrently with the CPU + hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch + hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices + hipDeviceAttributeDeviceOverlap, ///< Cuda only. Device can concurrently copy memory and execute a kernel. + ///< Deprecated. Use instead asyncEngineCount. + hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on + ///< the device without migration + hipDeviceAttributeGlobalL1CacheSupported, ///< Cuda only. Device supports caching globals in L1 + hipDeviceAttributeHostNativeAtomicSupported, ///< Cuda only. Link between the device and the host supports native atomic operations + hipDeviceAttributeIntegrated, ///< Device is integrated GPU + hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices. + hipDeviceAttributeKernelExecTimeout, ///< Run time limit for kernels executed on the device + hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2 cache. + hipDeviceAttributeLocalL1CacheSupported, ///< caching locals in L1 is supported + hipDeviceAttributeLuid, ///< Cuda only. 8-byte locally unique identifier in 8 bytes. Undefined on TCC and non-Windows platforms + hipDeviceAttributeLuidDeviceNodeMask, ///< Cuda only. Luid device node mask. Undefined on TCC and non-Windows platforms + hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number. + hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system + hipDeviceAttributeMaxBlocksPerMultiProcessor, ///< Cuda only. Max block size per multiprocessor + hipDeviceAttributeMaxBlockDimX, ///< Max block size in width. + hipDeviceAttributeMaxBlockDimY, ///< Max block size in height. + hipDeviceAttributeMaxBlockDimZ, ///< Max block size in depth. 
+ hipDeviceAttributeMaxGridDimX, ///< Max grid size in width. + hipDeviceAttributeMaxGridDimY, ///< Max grid size in height. + hipDeviceAttributeMaxGridDimZ, ///< Max grid size in depth. + hipDeviceAttributeMaxSurface1D, ///< Maximum size of 1D surface. + hipDeviceAttributeMaxSurface1DLayered, ///< Cuda only. Maximum dimensions of 1D layered surface. + hipDeviceAttributeMaxSurface2D, ///< Maximum dimension (width, height) of 2D surface. + hipDeviceAttributeMaxSurface2DLayered, ///< Cuda only. Maximum dimensions of 2D layered surface. + hipDeviceAttributeMaxSurface3D, ///< Maximum dimension (width, height, depth) of 3D surface. + hipDeviceAttributeMaxSurfaceCubemap, ///< Cuda only. Maximum dimensions of Cubemap surface. + hipDeviceAttributeMaxSurfaceCubemapLayered, ///< Cuda only. Maximum dimension of Cubemap layered surface. + hipDeviceAttributeMaxTexture1DWidth, ///< Maximum size of 1D texture. + hipDeviceAttributeMaxTexture1DLayered, ///< Cuda only. Maximum dimensions of 1D layered texture. + hipDeviceAttributeMaxTexture1DLinear, ///< Maximum number of elements allocatable in a 1D linear texture. + ///< Use cudaDeviceGetTexture1DLinearMaxWidth() instead on Cuda. + hipDeviceAttributeMaxTexture1DMipmap, ///< Cuda only. Maximum size of 1D mipmapped texture. + hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D texture. + hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension hight of 2D texture. + hipDeviceAttributeMaxTexture2DGather, ///< Cuda only. Maximum dimensions of 2D texture if gather operations performed. + hipDeviceAttributeMaxTexture2DLayered, ///< Cuda only. Maximum dimensions of 2D layered texture. + hipDeviceAttributeMaxTexture2DLinear, ///< Cuda only. Maximum dimensions (width, height, pitch) of 2D textures bound to pitched memory. + hipDeviceAttributeMaxTexture2DMipmap, ///< Cuda only. Maximum dimensions of 2D mipmapped texture. + hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D texture. 
+ hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimension height of 3D texture. + hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimension depth of 3D texture. + hipDeviceAttributeMaxTexture3DAlt, ///< Cuda only. Maximum dimensions of alternate 3D texture. + hipDeviceAttributeMaxTextureCubemap, ///< Cuda only. Maximum dimensions of Cubemap texture + hipDeviceAttributeMaxTextureCubemapLayered, ///< Cuda only. Maximum dimensions of Cubemap layered texture. + hipDeviceAttributeMaxThreadsDim, ///< Maximum dimension of a block + hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block. + hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per multiprocessor. + hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies + hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits. + hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz. + hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number. + hipDeviceAttributeMultiGpuBoardGroupID, ///< Cuda only. Unique ID of device group on the same multi-GPU board + hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device. + hipDeviceAttributeName, ///< Device name. + hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory + ///< without calling hipHostRegister on it + hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via the host's page tables + hipDeviceAttributePciBusId, ///< PCI Bus ID. + hipDeviceAttributePciDeviceId, ///< PCI Device ID. + hipDeviceAttributePciDomainID, ///< PCI Domain ID. + hipDeviceAttributePersistingL2CacheMaxSize, ///< Cuda11 only. Maximum l2 persisting lines capacity in bytes + hipDeviceAttributeMaxRegistersPerBlock, ///< 32-bit registers available to a thread block. 
This number is shared + ///< by all thread blocks simultaneously resident on a multiprocessor. + hipDeviceAttributeMaxRegistersPerMultiprocessor, ///< 32-bit registers available per block. + hipDeviceAttributeReservedSharedMemPerBlock, ///< Cuda11 only. Shared memory reserved by CUDA driver per block. + hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in bytes. + hipDeviceAttributeSharedMemPerBlockOptin, ///< Cuda only. Maximum shared memory per block usable by special opt in. + hipDeviceAttributeSharedMemPerMultiprocessor, ///< Cuda only. Shared memory available per multiprocessor. + hipDeviceAttributeSingleToDoublePrecisionPerfRatio, ///< Cuda only. Performance ratio of single precision to double precision. + hipDeviceAttributeStreamPrioritiesSupported, ///< Cuda only. Whether to support stream priorities. + hipDeviceAttributeSurfaceAlignment, ///< Cuda only. Alignment requirement for surfaces + hipDeviceAttributeTccDriver, ///< Cuda only. Whether device is a Tesla device using TCC driver + hipDeviceAttributeTextureAlignment, ///< Alignment requirement for textures + hipDeviceAttributeTexturePitchAlignment, ///< Pitch alignment requirement for 2D texture references bound to pitched memory; + hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes. + hipDeviceAttributeTotalGlobalMem, ///< Global memory available on devicice. + hipDeviceAttributeUnifiedAddressing, ///< Cuda only. An unified address space shared with the host. + hipDeviceAttributeUuid, ///< Cuda only. Unique ID in 16 byte. + hipDeviceAttributeWarpSize, ///< Warp size in threads. 
+ hipDeviceAttributeCudaCompatibleEnd = 9999, + hipDeviceAttributeAmdSpecificBegin = 10000, + hipDeviceAttributeClockInstructionRate = hipDeviceAttributeAmdSpecificBegin, ///< Frequency in khz of the timer used by the device-side "clock*" + hipDeviceAttributeArch, ///< Device architecture + hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory PerMultiprocessor. + hipDeviceAttributeGcnArch, ///< Device gcn architecture + hipDeviceAttributeGcnArchName, ///< Device gcnArch name in 256 bytes + hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register + hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register + hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple + ///< devices with unmatched functions + hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple + ///< devices with unmatched grid dimensions + hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple + ///< devices with unmatched block dimensions + hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple + ///< devices with unmatched shared memories + hipDeviceAttributeIsLargeBar, ///< Whether it is LargeBar + hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device + hipDeviceAttributeCanUseStreamWaitValue, ///< '1' if Device supports hipStreamWaitValue32() and + ///< hipStreamWaitValue64() , '0' otherwise. 
+ hipDeviceAttributeAmdSpecificEnd = 19999, + hipDeviceAttributeVendorSpecificBegin = 20000, + // Extended attributes for vendors +} hipDeviceAttribute_t; + +typedef struct HIPdevprop_st { + int maxThreadsPerBlock; + int maxThreadsDim[3]; + int maxGridSize[3]; + int sharedMemPerBlock; + int totalConstantMemory; + int SIMDWidth; + int memPitch; + int regsPerBlock; + int clockRate; + int textureAlign; +} HIPdevprop; + +typedef struct { + // 32-bit Atomics + unsigned hasGlobalInt32Atomics : 1; ///< 32-bit integer atomics for global memory. + unsigned hasGlobalFloatAtomicExch : 1; ///< 32-bit float atomic exch for global memory. + unsigned hasSharedInt32Atomics : 1; ///< 32-bit integer atomics for shared memory. + unsigned hasSharedFloatAtomicExch : 1; ///< 32-bit float atomic exch for shared memory. + unsigned hasFloatAtomicAdd : 1; ///< 32-bit float atomic add in global and shared memory. + + // 64-bit Atomics + unsigned hasGlobalInt64Atomics : 1; ///< 64-bit integer atomics for global memory. + unsigned hasSharedInt64Atomics : 1; ///< 64-bit integer atomics for shared memory. + + // Doubles + unsigned hasDoubles : 1; ///< Double-precision floating point. + + // Warp cross-lane operations + unsigned hasWarpVote : 1; ///< Warp vote instructions (__any, __all). + unsigned hasWarpBallot : 1; ///< Warp ballot instructions (__ballot). + unsigned hasWarpShuffle : 1; ///< Warp shuffle operations. (__shfl_*). + unsigned hasFunnelShift : 1; ///< Funnel two words into one with shift&mask caps. + + // Sync + unsigned hasThreadFenceSystem : 1; ///< __threadfence_system. + unsigned hasSyncThreadsExt : 1; ///< __syncthreads_count, syncthreads_and, syncthreads_or. + + // Misc + unsigned hasSurfaceFuncs : 1; ///< Surface functions. + unsigned has3dGrid : 1; ///< Grid and group dims are 3D (rather than 2D). + unsigned hasDynamicParallelism : 1; ///< Dynamic parallelism. +} hipDeviceArch_t; + +typedef struct hipDeviceProp_t { + char name[256]; ///< Device name. 
+ size_t totalGlobalMem; ///< Size of global memory region (in bytes). + size_t sharedMemPerBlock; ///< Size of shared memory region (in bytes). + int regsPerBlock; ///< Registers per block. + int warpSize; ///< Warp size. + int maxThreadsPerBlock; ///< Max work items per work group or workgroup max size. + int maxThreadsDim[3]; ///< Max number of threads in each dimension (XYZ) of a block. + int maxGridSize[3]; ///< Max grid dimensions (XYZ). + int clockRate; ///< Max clock frequency of the multiProcessors in khz. + int memoryClockRate; ///< Max global memory clock frequency in khz. + int memoryBusWidth; ///< Global memory bus width in bits. + size_t totalConstMem; ///< Size of shared memory region (in bytes). + int major; ///< Major compute capability. On HCC, this is an approximation and features may + ///< differ from CUDA CC. See the arch feature flags for portable ways to query + ///< feature caps. + int minor; ///< Minor compute capability. On HCC, this is an approximation and features may + ///< differ from CUDA CC. See the arch feature flags for portable ways to query + ///< feature caps. + int multiProcessorCount; ///< Number of multi-processors (compute units). + int l2CacheSize; ///< L2 cache size. + int maxThreadsPerMultiProcessor; ///< Maximum resident threads per multi-processor. + int computeMode; ///< Compute mode. + int clockInstructionRate; ///< Frequency in khz of the timer used by the device-side "clock*" + ///< instructions. New for HIP. + hipDeviceArch_t arch; ///< Architectural feature flags. New for HIP. + int concurrentKernels; ///< Device can possibly execute multiple kernels concurrently. + int pciDomainID; ///< PCI Domain ID + int pciBusID; ///< PCI Bus ID. + int pciDeviceID; ///< PCI Device ID. + size_t maxSharedMemoryPerMultiProcessor; ///< Maximum Shared Memory Per Multiprocessor. + int isMultiGpuBoard; ///< 1 if device is on a multi-GPU board, 0 if not. 
+ int canMapHostMemory; ///< Check whether HIP can map host memory + int gcnArch; ///< DEPRECATED: use gcnArchName instead + char gcnArchName[256]; ///< AMD GCN Arch Name. + int integrated; ///< APU vs dGPU + int cooperativeLaunch; ///< HIP device supports cooperative launch + int cooperativeMultiDeviceLaunch; ///< HIP device supports cooperative launch on multiple devices + int maxTexture1DLinear; ///< Maximum size for 1D textures bound to linear memory + int maxTexture1D; ///< Maximum number of elements in 1D images + int maxTexture2D[2]; ///< Maximum dimensions (width, height) of 2D images, in image elements + int maxTexture3D[3]; ///< Maximum dimensions (width, height, depth) of 3D images, in image elements + unsigned int* hdpMemFlushCntl; ///< Addres of HDP_MEM_COHERENCY_FLUSH_CNTL register + unsigned int* hdpRegFlushCntl; ///< Addres of HDP_REG_COHERENCY_FLUSH_CNTL register + size_t memPitch; /// Date: Sat, 26 Aug 2023 21:30:48 -0500 Subject: [PATCH 34/44] Revert "[SPIRV] Reduce the number of warps used by subgroup reduce" This reverts commit 31e7635b99a0472913f51ff381365a27a72231dc. 
--- compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp | 4 +--- .../compiler/Codegen/SPIRV/test/config_default_matvec.mlir | 2 +- .../compiler/Codegen/SPIRV/test/config_default_reduction.mlir | 4 ++-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index ec3bb090ec50f..1579eabf27dc3 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -40,7 +40,6 @@ using llvm::APIntOps::GreatestCommonDivisor; constexpr unsigned numTilesPerSubgroupDimK = 2; constexpr int kMaxVectorNumBits = 128; -constexpr int kPreferredReductionVectorUnrollAmount = 8; namespace mlir { namespace iree_compiler { @@ -1260,8 +1259,7 @@ static LogicalResult setReductionConfig(const spirv::TargetEnv &targetEnv, return failure(); // Let each thread handle `vectorSize` elements. - unsigned vectorSize = - kPreferredReductionVectorUnrollAmount * kMaxVectorNumBits / bitWidth; + unsigned vectorSize = kMaxVectorNumBits / bitWidth; while ((dimSize / vectorSize) % subgroupSize != 0) vectorSize /= 2; diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir index 4cbe804a85224..f60f5a8d4bdbc 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir @@ -140,7 +140,7 @@ hal.executable @i4_dequant_matvec_f32 { // CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK-LABEL: hal.executable.export public @i4_dequant_matvec_f32 // CHECK-SAME: translation_info = #[[$TRANSLATION]] -// CHECK-SAME: workgroup_size = [128 : index, 1 : index, 1 : index] +// CHECK-SAME: workgroup_size = [1024 : index, 1 : index, 1 : index] // CHECK: func.func @i4_dequant_matvec_f32() // CHECK: 
linalg.generic // CHECK-SAME: lowering_config = #[[$CONFIG]] diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir index 896b735aa02bb..7b08b772ddf1f 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir @@ -47,7 +47,7 @@ hal.executable private @subgroup_reduce_f32 { // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @subgroup_reduce_f32 // CHECK-SAME: translation_info = #[[TRANSLATION]] -// CHECK-SAME: workgroup_size = [16 : index, 1 : index, 1 : index] +// CHECK-SAME: workgroup_size = [128 : index, 1 : index, 1 : index] // CHECK: func.func @subgroup_reduce_f32() // CHECK: linalg.generic // CHECK-SAME: lowering_config = #[[CONFIG]] @@ -105,7 +105,7 @@ hal.executable private @subgroup_reduce_f16 { // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: hal.executable.export public @subgroup_reduce_f16 // CHECK-SAME: translation_info = #[[TRANSLATION]] -// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index] +// CHECK-SAME: workgroup_size = [512 : index, 1 : index, 1 : index] // CHECK: func.func @subgroup_reduce_f16() // CHECK: linalg.generic // CHECK-SAME: lowering_config = #[[CONFIG]] From 8988da2f412fa3e70a453e252c2bf92bb55408f4 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Sun, 27 Aug 2023 11:49:39 -0700 Subject: [PATCH 35/44] [ROCM] Enable WarpReduction on ROCM + Matvec on GPU. 
--- .../compiler/Codegen/LLVMGPU/KernelConfig.cpp | 101 ++++++++++++++++-- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 4 +- 2 files changed, 92 insertions(+), 13 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index 10e53b29e542a..b9ad1739b3dc0 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -30,6 +30,7 @@ using namespace mlir::iree_compiler; static constexpr unsigned cudaWarpSize = 32; static constexpr StringLiteral kCudaTarget = "cuda"; +static constexpr StringLiteral kRocmTarget = "rocm"; namespace mlir { namespace iree_compiler { llvm::cl::opt clGPUCodegenTransformDialectFileName( @@ -162,11 +163,19 @@ bool isCudaTarget(func::FuncOp entryPoint) { return false; } -static TargetInfo getTargetInfo(func::FuncOp entryPoint) { +bool isRocmTarget(func::FuncOp entryPoint) { + if (auto variantOp = + entryPoint->getParentOfType()) { + IREE::HAL::ExecutableTargetAttr targetAttr = variantOp.getTarget(); + if (auto backend = targetAttr.getBackend()) { + return backend.getValue().str() == kRocmTarget; + } + } + return false; +} + +static TargetInfo getCudaTargetInfo(func::FuncOp entryPoint) { TargetInfo info; - // TODO: fill out target info for other vendors. - if (!isCudaTarget(entryPoint)) - return info; // All the cuda target are assumed to have warp support. info.hasWarpShuffle = true; StringRef targetName = getTargetArch(entryPoint); @@ -190,6 +199,34 @@ static TargetInfo getTargetInfo(func::FuncOp entryPoint) { return info; } +// TODO: Plumb in WarpSize into TargetInfo for wave64 systems. +static TargetInfo getRocmTargetInfo(func::FuncOp entryPoint) { + TargetInfo info; + StringRef targetName = getTargetArch(entryPoint); + // If no target name is set assume all the features are off. 
+ if (targetName == "") + return info; + if (!StringRef(targetName).starts_with("gfx")) { + entryPoint.emitError("unknown target name ") << targetName; + return info; + } + // Assumes all gfx has warp shuffle. + info.hasWarpShuffle = true; + // TODO: Check and enable for WMMA once pipeline is available. + return info; +} + +static TargetInfo getTargetInfo(func::FuncOp entryPoint) { + TargetInfo info; + // TODO: fill out target info for other vendors. + if (isCudaTarget(entryPoint)) { + info = getCudaTargetInfo(entryPoint); + } else if (isRocmTarget(entryPoint)) { + info = getRocmTargetInfo(entryPoint); + } + return info; +} + static bool supportsTensorCore(func::FuncOp entryPoint, linalg::LinalgOp op, const TargetInfo &targetInfo) { // Limit tensor core pipeline to matmul as not all combinations of transpose @@ -254,6 +291,20 @@ static LogicalResult setContractConfig(func::FuncOp entryPoint, if (!linalg::isaContractionOpInterface(op) || op.getNumParallelLoops() < 2) { return failure(); } + + // Also exclude the case of matvec, which has only one non-unit parallel dim. + // They should go down different pipelines. + int nonUnitParallelDimCount = 0; + SmallVector bounds = op.getStaticLoopRanges(); + SmallVector kinds = op.getIteratorTypesArray(); + for (auto [kind, bound] : llvm::zip(kinds, bounds)) { + if (kind == utils::IteratorType::parallel) + nonUnitParallelDimCount += bound != 1; + } + if (!isa(op) && + nonUnitParallelDimCount == 1) + return failure(); + // Don't consider operations that don't have a broadcast, those should go // through reductions. if (llvm::any_of(op.getIndexingMapsArray(), @@ -750,13 +801,24 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, } SmallVector reductionDims; op.getReductionDims(reductionDims); - if (reductionDims.size() != 1 || reductionDims[0] != op.getNumLoops() - 1) + if (reductionDims.empty()) return failure(); + + // Make sure reduction dimensions are the innermost ones. 
+ for (int i = 0; i < reductionDims.size(); ++i) { + if (reductionDims[reductionDims.size() - 1 - i] != + op.getNumLoops() - 1 - i) { + return failure(); + } + } + if (op.getRegionOutputArgs().size() != 1) return failure(); + // Only support projected permutation, this could be extended to projected // permutated with broadcast. + if (llvm::any_of(op.getDpsInputOperands(), [&](OpOperand *input) { return !op.getMatchingIndexingMap(input).isProjectedPermutation(); })) @@ -779,8 +841,12 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, if (!foundSingleReductionOutput) return failure(); - std::optional dimSize = getLinalgDimSize(op, reductionDims[0]); - if (!dimSize || *dimSize % cudaWarpSize != 0) + + SmallVector bounds = op.getStaticLoopRanges(); + int64_t dimSize = 1; + for (int64_t dim : reductionDims) + dimSize *= bounds[dim]; + if (dimSize % cudaWarpSize != 0) return failure(); const Type elementType = @@ -793,14 +859,15 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8) return failure(); + const unsigned largestLoadSizeInBits = 128; unsigned vectorSize = largestLoadSizeInBits / bitWidth; - while ((*dimSize / vectorSize) % cudaWarpSize != 0) + while ((dimSize / vectorSize) % cudaWarpSize != 0) vectorSize /= 2; // TODO: Add reduction tiling to handle larger reductions. const int64_t maxWorkgroupSize = 1024; - int64_t groupSize = *dimSize / vectorSize; + int64_t groupSize = dimSize / vectorSize; if (groupSize > maxWorkgroupSize) { groupSize = llvm::APIntOps::GreatestCommonDivisor( {64, uint64_t(groupSize)}, {64, uint64_t(maxWorkgroupSize)}) @@ -813,8 +880,20 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint, size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1; // Tile all the parallel dimension to 1. 
SmallVector workgroupTileSizes(numLoops, 1); - SmallVector reductionTileSizes(numLoops, 0); - reductionTileSizes.push_back(groupSize * vectorSize); + SmallVector reductionTileSizes(op.getNumLoops(), 0); + int64_t remaingGroupSize = groupSize; + for (int i = reductionDims.size() - 1; i >= 0; --i) { + int64_t dim = reductionDims[i]; + int64_t bound = bounds[dim]; + if (i == reductionDims.size() - 1) + bound /= vectorSize; + APInt size = llvm::APIntOps::GreatestCommonDivisor( + {64, uint64_t(remaingGroupSize)}, {64, uint64_t(bound)}); + reductionTileSizes[dim] = size.getSExtValue(); + if (i == reductionDims.size() - 1) + reductionTileSizes[dim] *= vectorSize; + remaingGroupSize /= size.getSExtValue(); + } TileSizesListType tileSizes; tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level tileSizes.emplace_back(std::move(reductionTileSizes)); // reduction level diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index d4503cc806eeb..ccde7e9eef8da 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -374,6 +374,8 @@ void addGPUTransposePassPipeline(OpPassManager &pm) { void addGPUWarpReductionPassPipeline(OpPassManager &pm) { tileAndDistributeToWorkgroup(pm); auto &nestedModulePM = pm.nest(); + nestedModulePM.addNestedPass( + createRematerializeParallelOpsPass()); nestedModulePM.addNestedPass(createCanonicalizerPass()); nestedModulePM.addNestedPass(createGPUTileReductionPass()); nestedModulePM.addNestedPass(createCanonicalizerPass()); @@ -581,8 +583,6 @@ void addGPUTransformDialectPasses(OpPassManager &passManager) { void buildLLVMGPUTransformPassPipeline(OpPassManager &pm, bool useROCM) { addCommonTargetExecutablePreprocessingPasses(pm.nest()); - pm.nest().addNestedPass( - createRematerializeParallelOpsPass()); pm.addPass(createLLVMGPULowerExecutableTargetPass()); OpPassManager &nestedModulePM = 
pm.nest(); //===--------------------------------------------------------------------===// From 883fad4da4bb8bc9e3f9908cba28f64db18397bf Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 29 Aug 2023 14:51:46 -0500 Subject: [PATCH 36/44] [ROCM] Replace rocm sdk ld.lld with iree-lld for compile-time linkage. --- .../HAL/Target/ROCM/ROCMTargetUtils.cpp | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTargetUtils.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTargetUtils.cpp index c3064c9a5a6b5..f325efa9df866 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTargetUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTargetUtils.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.h" +#include "iree/compiler/Utils/ToolUtils.h" #include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Linker/Linker.h" @@ -161,39 +162,38 @@ std::string createHsaco(const std::string isa, StringRef name) { // Invoke lld. Expect a true return value from lld. // Searching for LLD - std::string lldProgram; - std::string toolName = "ld.lld"; - if (llvm::sys::fs::exists(toolName)) { - llvm::SmallString<256> absolutePath(toolName); - llvm::sys::fs::make_absolute(absolutePath); - lldProgram = std::string(absolutePath); - } else { - // Next search the environment path. 
- if (auto result = llvm::sys::Process::FindInEnvPath("PATH", toolName)) { - lldProgram = std::string(*result); - } - } + const SmallVector &toolNames{"iree-lld"}; + std::string lldProgram = findTool(toolNames); if (lldProgram.empty()) { llvm::WithColor::error(llvm::errs(), name) - << "unable to find ld.lld in PATH\n"; + << "unable to find iree-lld.\n"; return {}; } // Setting Up LLD Args + if ( lldProgram.front() == '"' ) { + lldProgram.erase( 0, 1 ); // erase the first character + lldProgram.erase( lldProgram.size() - 1 ); // erase the last character + } +#if defined(_WIN32) + llvm::StringRef lldName = "iree-lld.exe"; +#else + llvm::StringRef lldName = "iree-lld"; +#endif // _WIN32 std::vector lldArgs{ - llvm::StringRef("ld.lld"), llvm::StringRef("-flavor"), + lldName, llvm::StringRef("-flavor"), llvm::StringRef("gnu"), llvm::StringRef("-shared"), tempIsaBinaryFilename.str(), llvm::StringRef("-o"), tempHsacoFilename.str(), }; - + // Executing LLD std::string errorMessage; int lldResult = llvm::sys::ExecuteAndWait( - lldProgram, llvm::ArrayRef(lldArgs), std::nullopt, {}, 5, + lldProgram, llvm::ArrayRef(lldArgs), llvm::StringRef("LLD_VERSION=IREE"), {}, 5, 0, &errorMessage); if (lldResult) { llvm::WithColor::error(llvm::errs(), name) - << "ld.lld execute fail:" << errorMessage << "Error Code:" << lldResult + << "iree-lld execute fail:" << errorMessage << "Error Code:" << lldResult << "\n"; return {}; } From f8351ee49ef4d46556b1442c245f727762a05b15 Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Thu, 14 Sep 2023 00:59:20 -0700 Subject: [PATCH 37/44] [experimental][ROCM] Stream Command Buffer and Enable Shared mem --- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 4 +- .../Dialect/HAL/Target/ROCM/ROCMTarget.cpp | 10 +- experimental/rocm/CMakeLists.txt | 5 + experimental/rocm/api.h | 41 ++ experimental/rocm/context_wrapper.h | 1 + experimental/rocm/direct_command_buffer.c | 6 +- experimental/rocm/dynamic_symbol_tables.h | 3 + experimental/rocm/hip_headers.h | 6 + 
experimental/rocm/native_executable.c | 27 + experimental/rocm/native_executable.h | 1 + .../rocm/registration/driver_module.c | 23 +- experimental/rocm/rocm_device.c | 93 ++- experimental/rocm/rocm_device.h | 1 + experimental/rocm/rocm_driver.c | 10 +- experimental/rocm/stream_command_buffer.c | 562 ++++++++++++++++++ experimental/rocm/stream_command_buffer.h | 49 ++ .../src/iree/schemas/rocm_executable_def.fbs | 3 + 17 files changed, 825 insertions(+), 20 deletions(-) create mode 100644 experimental/rocm/stream_command_buffer.c create mode 100644 experimental/rocm/stream_command_buffer.h diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index ccde7e9eef8da..834bdeb3877e3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -467,7 +467,7 @@ static void addLowerAndOptimzeAddressComputation(OpPassManager &pm) { pm.addPass(createExtractAddressComputationGPUPass()); pm.addNestedPass(memref::createExpandOpsPass()); pm.addPass(memref::createFoldMemRefAliasOpsPass()); - pm.addPass(memref::createExpandStridedMetadataPass()); + pm.addPass(createIREEExpandStridedMetadataPass()); // Hoist loop invariant variables to give decompose affine pass the right loop // dependencies. 
pm.addPass(createLoopInvariantCodeMotionPass()); @@ -543,7 +543,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &pm, bool useROCM) { pm.addNestedPass(memref::createExpandOpsPass()); pm.addPass(memref::createFoldMemRefAliasOpsPass()); - pm.addPass(memref::createExpandStridedMetadataPass()); + pm.addPass(createIREEExpandStridedMetadataPass()); pm.addPass(createEmulateNarrowTypePass()); pm.addPass(createLowerAffinePass()); pm.addPass(createCanonicalizerPass()); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp index cf8926f2f4062..0b99e641fd963 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp @@ -149,6 +149,7 @@ class ROCMTargetBackend final : public TargetBackend { exportOps[op.getSymName()] = op; } std::vector> workgroupSizes; + SmallVector workgroupLocalMemories; for (auto func : innerModuleOp.getOps()) { int32_t flatWgSize = 1; auto *llvmFunc = llvmModule->getFunction(func.getName()); @@ -166,6 +167,11 @@ class ROCMTargetBackend final : public TargetBackend { workgroupSize = {1, 1, 1}; } workgroupSizes.push_back(workgroupSize); + uint32_t workgroupLocalMemory = 0; + if (auto workgroupLocalMemoryAttr = exportOp.getWorkgroupLocalMemory()) { + workgroupLocalMemory = workgroupLocalMemoryAttr->getSExtValue(); + } + workgroupLocalMemories.push_back(workgroupLocalMemory); // For GPU kernels, // 1. Insert AMDGPU_KERNEL calling convention. // 2. Insert amdgpu-flat-workgroup-size(1, 256) attribute. 
@@ -231,9 +237,11 @@ class ROCMTargetBackend final : public TargetBackend { ++blockSizes; } auto blockSizesRef = iree_hal_rocm_BlockSizeDef_vec_end(builder); - + auto workgroupLocalMemoriesRef = + builder.createInt32Vec(workgroupLocalMemories); iree_hal_rocm_ExecutableDef_entry_points_add(builder, entryPointsRef); iree_hal_rocm_ExecutableDef_block_sizes_add(builder, blockSizesRef); + iree_hal_rocm_ExecutableDef_shared_memory_size_add(builder, workgroupLocalMemoriesRef); iree_hal_rocm_ExecutableDef_hsaco_image_add(builder, hsacoRef); iree_hal_rocm_ExecutableDef_end_as_root(builder); diff --git a/experimental/rocm/CMakeLists.txt b/experimental/rocm/CMakeLists.txt index 8b3c3b9eb2e2a..bc4ccfafa2643 100644 --- a/experimental/rocm/CMakeLists.txt +++ b/experimental/rocm/CMakeLists.txt @@ -63,6 +63,8 @@ iree_cc_library( "pipeline_layout.h" "status_util.c" "status_util.h" + "stream_command_buffer.c" + "stream_command_buffer.h" "tracing.c" "tracing.h" INCLUDES @@ -78,8 +80,11 @@ iree_cc_library( iree::base::internal::synchronization iree::hal iree::hal::utils::buffer_transfer + iree::hal::utils::collective_batch + iree::hal::utils::deferred_command_buffer iree::hal::utils::file_transfer iree::hal::utils::memory_file + iree::hal::utils::resource_set iree::hal::utils::semaphore_base iree::schemas::rocm_executable_def_c_fbs COPTS diff --git a/experimental/rocm/api.h b/experimental/rocm/api.h index 68fa1913bf2f9..7949ac407afa3 100644 --- a/experimental/rocm/api.h +++ b/experimental/rocm/api.h @@ -16,6 +16,46 @@ extern "C" { #endif // __cplusplus +//===----------------------------------------------------------------------===// +// iree_hal_rocm_device_t +//===----------------------------------------------------------------------===// + +// Defines how command buffers are recorded and executed. +typedef enum iree_hal_rocm_command_buffer_mode_e { + // Command buffers are recorded into ROCM null stream. 
+ IREE_HAL_ROCM_COMMAND_BUFFER_MODE_DIRECT = 0, + // Command buffers are directly issued against ROCM stream. + IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM = 1, +} iree_hal_rocm_command_buffer_mode_t; + +// Parameters configuring an iree_hal_rocm_device_t. +// Must be initialized with iree_hal_rocm_device_params_initialize prior to use. +typedef struct iree_hal_rocm_device_params_t { + + // Total size of each block in the device shared block pool. + // Larger sizes will lower overhead and ensure the heap isn't hit for + // transient allocations while also increasing memory consumption. + iree_host_size_t arena_block_size; + + // Specifies how command buffers are recorded and executed. + iree_hal_rocm_command_buffer_mode_t command_buffer_mode; + + // Enables tracing of command buffers when IREE tracing is enabled. + // May take advantage of additional extensions for more accurate timing or + // hardware-specific performance counters. + // + // NOTE: tracing has a non-trivial overhead and will skew the timing of + // submissions and introduce false barriers between dispatches. Use this to + // identify slow dispatches and refine from there; be wary of whole-program + // tracing with this enabled. + bool stream_tracing; + +} iree_hal_rocm_device_params_t; + +// Initializes |out_params| to default values. +IREE_API_EXPORT void iree_hal_rocm_device_params_initialize( + iree_hal_rocm_device_params_t* out_params); + //===----------------------------------------------------------------------===// // iree_hal_rocm_driver_t //===----------------------------------------------------------------------===// @@ -35,6 +75,7 @@ IREE_API_EXPORT void iree_hal_rocm_driver_options_initialize( // |out_driver| must be released by the caller (see |iree_hal_driver_release|). 
IREE_API_EXPORT iree_status_t iree_hal_rocm_driver_create( iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* default_params, const iree_hal_rocm_driver_options_t *options, iree_allocator_t host_allocator, iree_hal_driver_t **out_driver); diff --git a/experimental/rocm/context_wrapper.h b/experimental/rocm/context_wrapper.h index 2c51424640110..819b67a2fad2e 100644 --- a/experimental/rocm/context_wrapper.h +++ b/experimental/rocm/context_wrapper.h @@ -14,6 +14,7 @@ // Structure to wrap all objects constant within a context. This makes it // simpler to pass it to the different objects and saves memory. typedef struct iree_hal_rocm_context_wrapper_t { + hipDevice_t rocm_device; hipCtx_t rocm_context; iree_allocator_t host_allocator; iree_hal_rocm_dynamic_symbols_t *syms; diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c index 5a8f66057dedb..637142944413d 100644 --- a/experimental/rocm/direct_command_buffer.c +++ b/experimental/rocm/direct_command_buffer.c @@ -213,7 +213,7 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_fill_buffer( ROCM_RETURN_IF_ERROR( command_buffer->context->syms, hipMemsetD8Async(dst, *(const uint8_t*)(pattern), num_elements, 0), - "hipMemsetD*Async"); + "hipMemsetD8Async"); break; } default: @@ -371,8 +371,8 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch( command_buffer->context->syms, hipModuleLaunchKernel(kernel_params.function, workgroup_x, workgroup_y, workgroup_z, kernel_params.block_size[0], - kernel_params.block_size[1], - kernel_params.block_size[2], 0, 0, + kernel_params.block_size[1], kernel_params.block_size[2], + kernel_params.shared_memory_size, 0, command_buffer->current_descriptor, NULL), "hipModuleLaunchKernel"); diff --git a/experimental/rocm/dynamic_symbol_tables.h b/experimental/rocm/dynamic_symbol_tables.h index 318374d9c44ab..5cc1ea62a248e 100644 --- a/experimental/rocm/dynamic_symbol_tables.h +++ 
b/experimental/rocm/dynamic_symbol_tables.h @@ -30,6 +30,7 @@ RC_PFN_DECL(hipMemsetD32Async, void *, int, size_t, hipStream_t) RC_PFN_DECL(hipMemsetD16Async, void *, short, size_t, hipStream_t) RC_PFN_DECL(hipMemsetD8Async, void *, char, size_t, hipStream_t) RC_PFN_DECL(hipMemcpy, void *, const void *, size_t, hipMemcpyKind) +RC_PFN_DECL(hipMemcpyHtoDAsync, hipDeviceptr_t, void *, size_t, hipStream_t) RC_PFN_DECL(hipMemcpyAsync, void *, const void *, size_t, hipMemcpyKind, hipStream_t) RC_PFN_DECL(hipMalloc, void **, size_t) @@ -53,3 +54,5 @@ RC_PFN_DECL(hipEventElapsedTime, float *, hipEvent_t, hipEvent_t) RC_PFN_DECL(hipEventQuery, hipEvent_t) RC_PFN_DECL(hipEventRecord, hipEvent_t, hipStream_t) RC_PFN_DECL(hipEventSynchronize, hipEvent_t) +RC_PFN_DECL(hipDeviceGetAttribute, int *, hipDeviceAttribute_t, int) +RC_PFN_DECL(hipFuncSetAttribute, const void *, hipFuncAttribute, int) diff --git a/experimental/rocm/hip_headers.h b/experimental/rocm/hip_headers.h index 28bcd5667482b..15193c17e9622 100644 --- a/experimental/rocm/hip_headers.h +++ b/experimental/rocm/hip_headers.h @@ -434,6 +434,12 @@ typedef enum hipDeviceAttribute_t { // Extended attributes for vendors } hipDeviceAttribute_t; +typedef enum hipFuncAttribute { + hipFuncAttributeMaxDynamicSharedMemorySize = 8, + hipFuncAttributePreferredSharedMemoryCarveout = 9, + hipFuncAttributeMax, +} hipFuncAttribute; + typedef struct HIPdevprop_st { int maxThreadsPerBlock; int maxThreadsDim[3]; diff --git a/experimental/rocm/native_executable.c b/experimental/rocm/native_executable.c index f82dcf5b78893..ffe7fc77d507b 100644 --- a/experimental/rocm/native_executable.c +++ b/experimental/rocm/native_executable.c @@ -61,6 +61,8 @@ iree_status_t iree_hal_rocm_native_executable_create( iree_hal_rocm_ExecutableDef_entry_points_get(executable_def); iree_hal_rocm_BlockSizeDef_vec_t block_sizes_vec = iree_hal_rocm_ExecutableDef_block_sizes_get(executable_def); + flatbuffers_uint32_vec_t shared_memory_sizes = + 
iree_hal_rocm_ExecutableDef_shared_memory_size_get(executable_def); iree_host_size_t entry_count = flatbuffers_string_vec_len(entry_points_vec); // Calculate the total number of characters across all entry point names. This @@ -111,6 +113,30 @@ iree_status_t iree_hal_rocm_native_executable_create( entry_name); break; } + + int32_t max_shared_mem = 0; + status = ROCM_RESULT_TO_STATUS( + context->syms, + hipDeviceGetAttribute( + &max_shared_mem, + hipDeviceAttributeMaxSharedMemoryPerBlock, + context->rocm_device), + "hipDeviceGetAttribute"); + if (!iree_status_is_ok(status)) break; + if (shared_memory_sizes[i] > max_shared_mem) { + status = iree_make_status(IREE_STATUS_INTERNAL, + "ROCM driver error: Requested shared memory " + "size of %d larger than allowed size of %d", + shared_memory_sizes[i], max_shared_mem); + } else if (shared_memory_sizes[i] != 0){ + status = ROCM_RESULT_TO_STATUS( + context->syms, + hipFuncSetAttribute(function, + HIP_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_memory_sizes[i]), + "hipFuncSetAttribute"); + } + // Package required parameters for kernel launches for each entry point. iree_hal_rocm_kernel_params_t* params = &executable->entry_points[i]; params->layout = executable_params->pipeline_layouts[i]; @@ -119,6 +145,7 @@ iree_status_t iree_hal_rocm_native_executable_create( params->block_size[0] = block_sizes_vec[i].x; params->block_size[1] = block_sizes_vec[i].y; params->block_size[2] = block_sizes_vec[i].z; + params->shared_memory_size = shared_memory_sizes[i]; // Stash the entry point name in the string table for use when tracing. 
IREE_TRACE({ iree_host_size_t entry_name_length = diff --git a/experimental/rocm/native_executable.h b/experimental/rocm/native_executable.h index 9789e23e5bf18..0c229a03b5253 100644 --- a/experimental/rocm/native_executable.h +++ b/experimental/rocm/native_executable.h @@ -22,6 +22,7 @@ typedef struct iree_hal_rocm_kernel_params_t { iree_hal_pipeline_layout_t* layout; hipFunction_t function; uint32_t block_size[3]; + uint32_t shared_memory_size; IREE_TRACE(iree_string_view_t function_name;) IREE_TRACE(iree_string_view_t source_filename;) IREE_TRACE(uint32_t source_line;) diff --git a/experimental/rocm/registration/driver_module.c b/experimental/rocm/registration/driver_module.c index fcdadfe3c112e..f1e180a91803b 100644 --- a/experimental/rocm/registration/driver_module.c +++ b/experimental/rocm/registration/driver_module.c @@ -11,6 +11,19 @@ #include "experimental/rocm/api.h" #include "iree/base/api.h" +#include "iree/base/internal/flags.h" + +// Force using ROCM streams until we support command buffer caching to avoid the +// overhead of graph creation. 
+IREE_FLAG( + bool, rocm_use_streams, true, + "Use ROCM streams for executing command buffers (instead of graphs)."); + +IREE_FLAG( + bool, rocm_tracing, true, + "Enables tracing of stream events when Tracy instrumentation is enabled.\n" + "Severely impacts benchmark timings and should only be used when\n" + "analyzing dispatch timings."); static iree_status_t iree_hal_rocm_driver_factory_enumerate( void *self, iree_host_size_t *out_driver_info_count, @@ -36,10 +49,18 @@ static iree_status_t iree_hal_rocm_driver_factory_try_create( (int)driver_name.size, driver_name.data); } IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_rocm_device_params_t default_params; + iree_hal_rocm_device_params_initialize(&default_params); + if (FLAG_rocm_use_streams) { + default_params.command_buffer_mode = + IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM; + } + default_params.stream_tracing = FLAG_rocm_tracing; iree_hal_rocm_driver_options_t driver_options; iree_hal_rocm_driver_options_initialize(&driver_options); iree_status_t status = iree_hal_rocm_driver_create( - driver_name, &driver_options, host_allocator, out_driver); + driver_name, &default_params, &driver_options, host_allocator, out_driver); IREE_TRACE_ZONE_END(z0); return status; } diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c index 8d563aec72e99..ba52d2ada00bd 100644 --- a/experimental/rocm/rocm_device.c +++ b/experimental/rocm/rocm_device.c @@ -19,9 +19,11 @@ #include "experimental/rocm/rocm_allocator.h" #include "experimental/rocm/rocm_event.h" #include "experimental/rocm/status_util.h" +#include "experimental/rocm//stream_command_buffer.h" #include "experimental/rocm/tracing.h" #include "iree/base/internal/arena.h" #include "iree/hal/utils/buffer_transfer.h" +#include "iree/hal/utils/deferred_command_buffer.h" #include "iree/hal/utils/file_transfer.h" #include "iree/hal/utils/memory_file.h" @@ -41,6 +43,9 @@ typedef struct iree_hal_rocm_device_t { // to ensure the symbols remains valid. 
iree_hal_driver_t* driver; + // Parameters used to control device behavior. + iree_hal_rocm_device_params_t params; + hipDevice_t device; // TODO: support multiple streams. @@ -51,6 +56,10 @@ typedef struct iree_hal_rocm_device_t { // Optional provider used for creating/configuring collective channels. iree_hal_channel_provider_t* channel_provider; + + // Cache of the direct stream command buffer initialized when in stream mode. + // TODO: have one cached per stream once there are multiple streams. + iree_hal_command_buffer_t* stream_command_buffer; } iree_hal_rocm_device_t; static const iree_hal_device_vtable_t iree_hal_rocm_device_vtable; @@ -61,11 +70,21 @@ static iree_hal_rocm_device_t* iree_hal_rocm_device_cast( return (iree_hal_rocm_device_t*)base_value; } +IREE_API_EXPORT void iree_hal_rocm_device_params_initialize( + iree_hal_rocm_device_params_t* out_params) { + memset(out_params, 0, sizeof(*out_params)); + out_params->arena_block_size = 32*1024; + out_params->command_buffer_mode = IREE_HAL_ROCM_COMMAND_BUFFER_MODE_DIRECT; + out_params->stream_tracing = false; +} + static void iree_hal_rocm_device_destroy(iree_hal_device_t* base_device) { iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_command_buffer_release(device->stream_command_buffer); + // There should be no more buffers live that use the allocator. iree_hal_allocator_release(device->device_allocator); @@ -76,6 +95,8 @@ static void iree_hal_rocm_device_destroy(iree_hal_device_t* base_device) { ROCM_IGNORE_ERROR(device->context_wrapper.syms, hipStreamDestroy(device->stream)); + iree_arena_block_pool_deinitialize(&device->block_pool); + // Finally, destroy the device. 
iree_hal_driver_release(device->driver); @@ -86,9 +107,9 @@ static void iree_hal_rocm_device_destroy(iree_hal_device_t* base_device) { static iree_status_t iree_hal_rocm_device_create_internal( iree_hal_driver_t* driver, iree_string_view_t identifier, - hipDevice_t rocm_device, hipStream_t stream, hipCtx_t context, - iree_hal_rocm_dynamic_symbols_t* syms, iree_allocator_t host_allocator, - iree_hal_device_t** out_device) { + const iree_hal_rocm_device_params_t* params, hipDevice_t rocm_device, + hipStream_t stream, hipCtx_t context, iree_hal_rocm_dynamic_symbols_t* syms, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { iree_hal_rocm_device_t* device = NULL; iree_host_size_t total_size = sizeof(*device) + identifier.size; IREE_RETURN_IF_ERROR( @@ -100,19 +121,36 @@ static iree_status_t iree_hal_rocm_device_create_internal( uint8_t* buffer_ptr = (uint8_t*)device + sizeof(*device); buffer_ptr += iree_string_view_append_to_buffer( identifier, &device->identifier, (char*)buffer_ptr); + device->params = *params; device->device = rocm_device; device->stream = stream; + device->context_wrapper.rocm_device = rocm_device; device->context_wrapper.rocm_context = context; device->context_wrapper.host_allocator = host_allocator; + iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, + &device->block_pool); device->context_wrapper.syms = syms; // Enable tracing for the (currently only) stream - no-op if disabled. 
- iree_status_t status = iree_hal_rocm_tracing_context_allocate( + iree_status_t status = iree_ok_status(); + if (device->params.stream_tracing) { + status = iree_hal_rocm_tracing_context_allocate( &device->context_wrapper, device->identifier, stream, &device->block_pool, host_allocator, &device->tracing_context); + } if (iree_status_is_ok(status)) { status = iree_hal_rocm_allocator_create(&device->context_wrapper, &device->device_allocator); } + if (iree_status_is_ok(status) && + params->command_buffer_mode == IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM) { + status = iree_hal_rocm_stream_command_buffer_create( + (iree_hal_device_t*)device, &device->context_wrapper, + device->tracing_context, + IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION | + IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED, + IREE_HAL_COMMAND_CATEGORY_ANY, /*binding_capacity=*/0, device->stream, + &device->block_pool, &device->stream_command_buffer); + } if (iree_status_is_ok(status)) { *out_device = (iree_hal_device_t*)device; } else { @@ -123,10 +161,12 @@ static iree_status_t iree_hal_rocm_device_create_internal( iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* params, iree_hal_rocm_dynamic_symbols_t* syms, hipDevice_t device, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(params); IREE_TRACE_ZONE_BEGIN(z0); hipCtx_t context; IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -136,8 +176,8 @@ iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver, syms, hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); if (iree_status_is_ok(status)) { - status = iree_hal_rocm_device_create_internal(driver, identifier, device, - stream, context, syms, + status = iree_hal_rocm_device_create_internal(driver, identifier, params, + device, stream, context, syms, host_allocator, out_device); } if (!iree_status_is_ok(status)) { @@ -224,10 +264,21 @@ static iree_status_t 
iree_hal_rocm_device_create_command_buffer( iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) { iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); - return iree_hal_rocm_direct_command_buffer_create( - base_device, &device->context_wrapper, device->tracing_context, mode, - command_categories, queue_affinity, binding_capacity, &device->block_pool, - out_command_buffer); + switch (device->params.command_buffer_mode) { + case IREE_HAL_ROCM_COMMAND_BUFFER_MODE_DIRECT: + return iree_hal_rocm_direct_command_buffer_create( + base_device, &device->context_wrapper, device->tracing_context, mode, + command_categories, queue_affinity, binding_capacity, &device->block_pool, + out_command_buffer); + case IREE_HAL_ROCM_COMMAND_BUFFER_MODE_STREAM: + return iree_hal_deferred_command_buffer_create( + base_device, mode, command_categories, binding_capacity, + &device->block_pool, iree_hal_device_host_allocator(base_device), + out_command_buffer); + default: + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid command buffer mode"); + } } static iree_status_t iree_hal_rocm_device_create_descriptor_set_layout( @@ -381,8 +432,28 @@ static iree_status_t iree_hal_rocm_device_queue_execute( // synchronizes after every submit. // TODO(raikonenfnu): currently run on default/null stream, when cmd buffer // stream work with device->stream, we'll change + for (iree_host_size_t i = 0; i < command_buffer_count; i++) { + iree_hal_command_buffer_t* command_buffer = command_buffers[i]; + if (iree_hal_rocm_stream_command_buffer_isa(command_buffer)) { + // Nothing to do for an inline command buffer; all the work has already + // been submitted. When we support semaphores we'll still need to signal + // their completion but do not have to worry about any waits: if there + // were waits we wouldn't have been able to execute inline! 
+ } else if (iree_hal_rocm_direct_command_buffer_isa(command_buffer)) { + IREE_TRACE_ZONE_BEGIN_NAMED(z0, "hipStreamSynchronize"); + ROCM_RETURN_IF_ERROR(device->context_wrapper.syms, hipStreamSynchronize(0), + "hipStreamSynchronize"); + iree_hal_rocm_tracing_context_collect(device->tracing_context); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); + } else { + IREE_RETURN_IF_ERROR(iree_hal_deferred_command_buffer_apply( + command_buffers[i], device->stream_command_buffer, + iree_hal_buffer_binding_table_empty())); + } + } IREE_TRACE_ZONE_BEGIN_NAMED(z0, "hipStreamSynchronize"); - ROCM_RETURN_IF_ERROR(device->context_wrapper.syms, hipStreamSynchronize(0), + ROCM_RETURN_IF_ERROR(device->context_wrapper.syms, hipStreamSynchronize(device->stream), "hipStreamSynchronize"); iree_hal_rocm_tracing_context_collect(device->tracing_context); IREE_TRACE_ZONE_END(z0); diff --git a/experimental/rocm/rocm_device.h b/experimental/rocm/rocm_device.h index 083f4c7cddb66..7abd4e67ce365 100644 --- a/experimental/rocm/rocm_device.h +++ b/experimental/rocm/rocm_device.h @@ -19,6 +19,7 @@ extern "C" { // Creates a device that owns and manages its own hipContext. iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* params, iree_hal_rocm_dynamic_symbols_t* syms, hipDevice_t device, iree_allocator_t host_allocator, diff --git a/experimental/rocm/rocm_driver.c b/experimental/rocm/rocm_driver.c index e284554c46e6b..174e8e108ddcf 100644 --- a/experimental/rocm/rocm_driver.c +++ b/experimental/rocm/rocm_driver.c @@ -21,6 +21,7 @@ typedef struct iree_hal_rocm_driver_t { // We allow overriding so that multiple ROCM versions can be exposed in the // same process. iree_string_view_t identifier; + iree_hal_rocm_device_params_t default_params; int default_device_index; // ROCM symbols. 
iree_hal_rocm_dynamic_symbols_t syms; @@ -45,6 +46,7 @@ IREE_API_EXPORT void iree_hal_rocm_driver_options_initialize( static iree_status_t iree_hal_rocm_driver_create_internal( iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* default_params, const iree_hal_rocm_driver_options_t* options, iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { iree_hal_rocm_driver_t* driver = NULL; @@ -56,6 +58,8 @@ static iree_status_t iree_hal_rocm_driver_create_internal( iree_string_view_append_to_buffer( identifier, &driver->identifier, (char*)driver + total_size - identifier.size); + memcpy(&driver->default_params, default_params, + sizeof(driver->default_params)); driver->default_device_index = options->default_device_index; iree_status_t status = iree_hal_rocm_dynamic_symbols_initialize(host_allocator, &driver->syms); @@ -80,14 +84,16 @@ static void iree_hal_rocm_driver_destroy(iree_hal_driver_t* base_driver) { IREE_API_EXPORT iree_status_t iree_hal_rocm_driver_create( iree_string_view_t identifier, + const iree_hal_rocm_device_params_t* default_params, const iree_hal_rocm_driver_options_t* options, iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { + IREE_ASSERT_ARGUMENT(default_params); IREE_ASSERT_ARGUMENT(options); IREE_ASSERT_ARGUMENT(out_driver); IREE_TRACE_ZONE_BEGIN(z0); iree_status_t status = iree_hal_rocm_driver_create_internal( - identifier, options, host_allocator, out_driver); + identifier, default_params, options, host_allocator, out_driver); IREE_TRACE_ZONE_END(z0); return status; @@ -206,7 +212,7 @@ static iree_status_t iree_hal_rocm_driver_create_device_by_id( // Attempt to create the device. 
iree_status_t status = - iree_hal_rocm_device_create(base_driver, device_name, &driver->syms, + iree_hal_rocm_device_create(base_driver, device_name, &driver->default_params, &driver->syms, device, host_allocator, out_device); IREE_TRACE_ZONE_END(z0); diff --git a/experimental/rocm/stream_command_buffer.c b/experimental/rocm/stream_command_buffer.c new file mode 100644 index 0000000000000..f6d98df632752 --- /dev/null +++ b/experimental/rocm/stream_command_buffer.c @@ -0,0 +1,562 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "experimental/rocm/stream_command_buffer.h" + +#include "experimental/rocm/rocm_buffer.h" +#include "experimental/rocm/rocm_event.h" +#include "experimental/rocm/native_executable.h" +#include "experimental/rocm/pipeline_layout.h" +#include "experimental/rocm/status_util.h" +#include "iree/hal/utils/collective_batch.h" +#include "iree/hal/utils/resource_set.h" + +#define IREE_HAL_ROCM_MAX_BINDING_COUNT 64 +// Kernel arguments contains binding and push constants. +#define IREE_HAL_ROCM_MAX_KERNEL_ARG 128 + +typedef struct { + iree_hal_command_buffer_t base; + iree_hal_rocm_context_wrapper_t* context; + iree_hal_rocm_tracing_context_t* tracing_context; + hipStream_t stream; + + // Maintains a reference to all resources used within the command buffer. + // Reset on each begin. + iree_hal_resource_set_t* resource_set; + + // Staging arena used for host->device transfers. + // Used for when we need ROCM to be able to reference memory as it performs + // asynchronous operations. + iree_arena_allocator_t arena; + + // Iteratively constructed batch of collective operations. + iree_hal_collective_batch_t collective_batch; + + int32_t push_constant[IREE_HAL_ROCM_MAX_PUSH_CONSTANT_COUNT]; + + // Keep track of the current set of kernel arguments. 
+ void* current_descriptor[IREE_HAL_ROCM_MAX_KERNEL_ARG]; + hipDeviceptr_t* device_ptrs[IREE_HAL_ROCM_MAX_KERNEL_ARG]; +} iree_hal_rocm_stream_command_buffer_t; + +static const iree_hal_command_buffer_vtable_t + iree_hal_rocm_stream_command_buffer_vtable; + +static iree_hal_rocm_stream_command_buffer_t* +iree_hal_rocm_stream_command_buffer_cast( + iree_hal_command_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_stream_command_buffer_vtable); + return (iree_hal_rocm_stream_command_buffer_t*)base_value; +} + +iree_status_t iree_hal_rocm_stream_command_buffer_create( + iree_hal_device_t* device, iree_hal_rocm_context_wrapper_t* context, + iree_hal_rocm_tracing_context_t* tracing_context, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_host_size_t binding_capacity, hipStream_t stream, + iree_arena_block_pool_t* block_pool, + iree_hal_command_buffer_t** out_command_buffer) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(context); + IREE_ASSERT_ARGUMENT(out_command_buffer); + *out_command_buffer = NULL; + if (binding_capacity > 0) { + // TODO(#10144): support indirect command buffers with binding tables. 
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect command buffers not yet implemented"); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_rocm_stream_command_buffer_t* command_buffer = NULL; + iree_status_t status = + iree_allocator_malloc(context->host_allocator, sizeof(*command_buffer), + (void**)&command_buffer); + if (iree_status_is_ok(status)) { + iree_hal_command_buffer_initialize( + device, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY, + binding_capacity, &iree_hal_rocm_stream_command_buffer_vtable, + &command_buffer->base); + command_buffer->context = context; + command_buffer->tracing_context = tracing_context; + command_buffer->stream = stream; + iree_arena_initialize(block_pool, &command_buffer->arena); + for (size_t i = 0; i < IREE_HAL_ROCM_MAX_KERNEL_ARG; i++) { + command_buffer->current_descriptor[i] = &command_buffer->device_ptrs[i]; + } + + status = iree_hal_resource_set_allocate(block_pool, + &command_buffer->resource_set); + } + if (iree_status_is_ok(status)) { + iree_hal_collective_batch_initialize(&command_buffer->arena, + command_buffer->resource_set, + &command_buffer->collective_batch); + } + + *out_command_buffer = &command_buffer->base; + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_rocm_stream_command_buffer_destroy( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_collective_batch_deinitialize(&command_buffer->collective_batch); + iree_hal_resource_set_free(command_buffer->resource_set); + iree_arena_deinitialize(&command_buffer->arena); + iree_allocator_free(command_buffer->context->host_allocator, command_buffer); + + IREE_TRACE_ZONE_END(z0); +} + +bool iree_hal_rocm_stream_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer) { + return iree_hal_resource_is(&command_buffer->resource, + 
&iree_hal_rocm_stream_command_buffer_vtable); +} + +// Flushes any pending batched collective operations. +// Must be called before any other non-collective nodes are added to the graph +// or a barrier is encountered. +static iree_status_t iree_hal_rocm_stream_command_buffer_flush_collectives( + iree_hal_rocm_stream_command_buffer_t* command_buffer) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Collectives not implemented on ROCM"); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_begin( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + (void)command_buffer; + + IREE_ROCM_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer->tracing_context, command_buffer->stream, + /*file_name=*/NULL, 0, + /*line=*/0, /*func_name=*/NULL, 0, "iree_hal_rocm_stream_command_buffer", + strlen("iree_hal_rocm_stream_command_buffer")); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_end( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + // Reset the arena as there should be nothing using it now that we've + // dispatched all our operations inline. + // NOTE: the resource set may contain resources we need to drop as we don't + // need to keep them live any longer than it takes to schedule the + // operations. In a real command buffer we would be this stream command + // buffer is strictly used to perform inline execution/replay of + // deferred command buffers that are retaining the resources already. + // NOTE: reseting the arena invalidates the collective batch. 
+ iree_arena_reset(&command_buffer->arena); + iree_hal_collective_batch_deinitialize(&command_buffer->collective_batch); + iree_hal_resource_set_free(command_buffer->resource_set); + IREE_RETURN_IF_ERROR(iree_hal_resource_set_allocate( + command_buffer->arena.block_pool, &command_buffer->resource_set)); + iree_hal_collective_batch_initialize(&command_buffer->arena, + command_buffer->resource_set, + &command_buffer->collective_batch); + + IREE_ROCM_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->stream); + + return iree_ok_status(); +} + +static void iree_hal_rocm_stream_command_buffer_begin_debug_group( + iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label, + iree_hal_label_color_t label_color, + const iree_hal_label_location_t* location) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + (void)command_buffer; + + IREE_ROCM_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer->tracing_context, command_buffer->stream, + location ? location->file.data : NULL, location ? location->file.size : 0, + location ? location->line : 0, /*func_name=*/NULL, 0, label.data, + label.size); + + // TODO: pass along to CUPTI if available. +} + +static void iree_hal_rocm_stream_command_buffer_end_debug_group( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + (void)command_buffer; + + // TODO: pass along to CUPTI if available. 
+ + IREE_ROCM_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->stream); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_execution_barrier( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_hal_execution_barrier_flags_t flags, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { +// iree_hal_rocm_stream_command_buffer_t* command_buffer = +// iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + // TODO(raikonen): implement ROCM barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_signal_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { +// iree_hal_rocm_stream_command_buffer_t* command_buffer = +// iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + // TODO(raikonen): implement ROCM barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_reset_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { +// iree_hal_rocm_stream_command_buffer_t* command_buffer = +// iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + // TODO(raikonen): implement ROCM barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_wait_events( + iree_hal_command_buffer_t* base_command_buffer, + iree_host_size_t 
event_count, const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { +// iree_hal_rocm_stream_command_buffer_t* command_buffer = +// iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + // TODO(raikonen): implement ROCM barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_discard_buffer( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) { + // We could mark the memory as invalidated so that if managed ROCM does not + // try to copy it back to the host. + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_fill_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, const void* pattern, + iree_host_size_t pattern_length) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + hipDeviceptr_t target_device_buffer = iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_buffer)); + target_offset += iree_hal_buffer_byte_offset(target_buffer); + hipDeviceptr_t dst = + (hipDeviceptr_t)((uintptr_t)target_device_buffer + target_offset); + + size_t num_elements = length / pattern_length; + switch (pattern_length) { + case 4: { + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipMemsetD32Async(dst, *(const uint32_t*)(pattern), num_elements, + command_buffer->stream), + "hipMemsetD32Async"); + 
break; + } + case 2: { + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipMemsetD16Async(dst, *(const uint16_t*)(pattern), num_elements, + command_buffer->stream), + "hipMemsetD16Async"); + break; + } + case 1: { + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipMemsetD8Async(dst, *(const uint8_t*)(pattern), num_elements, + command_buffer->stream), + "hipMemsetD8Async"); + break; + } + default: + return iree_make_status(IREE_STATUS_INTERNAL, + "unsupported fill pattern length"); + } + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_update_buffer( + iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + // Allocate scratch space in the arena for the data and copy it in. + // The update buffer API requires that the command buffer capture the host + // memory at the time the method is called in case the caller wants to reuse + // the memory. Because ROCM memcpys are async if we didn't copy it's possible + // for the reused memory to change before the stream reaches the copy + // operation and get the wrong data. + const uint8_t* src = (const uint8_t*)source_buffer + source_offset; + if (command_buffer->arena.block_pool) { + uint8_t* storage = NULL; + IREE_RETURN_IF_ERROR( + iree_arena_allocate(&command_buffer->arena, length, (void**)&storage)); + memcpy(storage, src, length); + src = storage; + } + + // Issue the copy using the scratch memory as the source. 
+ hipDeviceptr_t target_device_buffer = iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_buffer)); + hipDeviceptr_t dst = (hipDeviceptr_t)((uintptr_t)target_device_buffer + + iree_hal_buffer_byte_offset(target_buffer) + target_offset); + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipMemcpyHtoDAsync(dst, (void*)src, length, command_buffer->stream), + "hipMemcpyHtoDAsync"); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_copy_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + hipDeviceptr_t target_device_buffer = iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_buffer)); + target_offset += iree_hal_buffer_byte_offset(target_buffer); + hipDeviceptr_t source_device_buffer = iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(source_buffer)); + source_offset += iree_hal_buffer_byte_offset(source_buffer); + hipDeviceptr_t dst = (hipDeviceptr_t)((uintptr_t)target_device_buffer + target_offset); + hipDeviceptr_t src = (hipDeviceptr_t)((uintptr_t)source_device_buffer + source_offset); + ROCM_RETURN_IF_ERROR(command_buffer->context->syms, + hipMemcpyAsync(dst, src, length, hipMemcpyDeviceToDevice, command_buffer->stream), + "hipMemcpyAsync"); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_collective( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel, + iree_hal_collective_op_t op, uint32_t param, + iree_hal_buffer_binding_t send_binding, + iree_hal_buffer_binding_t 
recv_binding, iree_device_size_t element_count) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + return iree_hal_collective_batch_append(&command_buffer->collective_batch, + channel, op, param, send_binding, + recv_binding, element_count); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_push_constants( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset, + const void* values, iree_host_size_t values_length) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + + iree_host_size_t constant_base_index = offset / sizeof(int32_t); + for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) { + command_buffer->push_constant[i + constant_base_index] = + ((uint32_t*)values)[i]; + } + + return iree_ok_status(); +} + +// Tie together the binding index and its index in |bindings| array. +typedef struct { + uint32_t index; + uint32_t binding; +} iree_hal_rocm_binding_mapping_t; + +// Helper to sort the binding based on their binding index. +static int compare_binding_index(const void* a, const void* b) { + const iree_hal_rocm_binding_mapping_t buffer_a = + *(const iree_hal_rocm_binding_mapping_t*)a; + const iree_hal_rocm_binding_mapping_t buffer_b = + *(const iree_hal_rocm_binding_mapping_t*)b; + return buffer_a.binding < buffer_b.binding ? 
-1 : 1; +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_push_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); + + iree_host_size_t base_binding = + iree_hal_rocm_base_binding_index(pipeline_layout, set); + + // Convention with the compiler side. We map bindings to kernel argument. + // We compact the bindings to get a dense set of arguments and keep them order + // based on the binding index. + // Sort the binding based on the binding index and map the array index to the + // argument index. + iree_hal_rocm_binding_mapping_t binding_used[IREE_HAL_ROCM_MAX_BINDING_COUNT]; + for (iree_host_size_t i = 0; i < binding_count; i++) { + iree_hal_rocm_binding_mapping_t buffer = {i, bindings[i].binding}; + binding_used[i] = buffer; + } + // TODO: remove this sort - it's thankfully small (1-8 on average) but we + // should be able to avoid it like we do on the CPU side with a bitmap. + qsort(binding_used, binding_count, sizeof(iree_hal_rocm_binding_mapping_t), + compare_binding_index); + assert(binding_count < IREE_HAL_ROCM_MAX_BINDING_COUNT && + "binding count larger than the max expected."); + + for (iree_host_size_t i = 0; i < binding_count; i++) { + iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index]; + hipDeviceptr_t device_ptr = + binding.buffer + ? 
(hipDeviceptr_t)((uintptr_t)iree_hal_rocm_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(binding.buffer)) + + iree_hal_buffer_byte_offset(binding.buffer) + binding.offset) + : 0; + *((hipDeviceptr_t*)command_buffer->current_descriptor[i + base_binding]) = + device_ptr; + } + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_dispatch( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + iree_hal_rocm_stream_command_buffer_t* command_buffer = + iree_hal_rocm_stream_command_buffer_cast(base_command_buffer); +// IREE_RETURN_IF_ERROR( +// iree_hal_rocm_stream_command_buffer_flush_collectives(command_buffer)); + + // Lookup kernel parameters used for side-channeling additional launch + // information from the compiler. + iree_hal_rocm_kernel_params_t kernel_params; + IREE_RETURN_IF_ERROR( + iree_hal_rocm_native_executable_entry_point_kernel_params( + executable, entry_point, &kernel_params)); + + IREE_ROCM_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer->tracing_context, command_buffer->stream, kernel_params.function_name.data, + kernel_params.function_name.size, + /*line=*/0, /*func_name=*/NULL, 0, kernel_params.function_name.data, + kernel_params.function_name.size); + + // Patch the push constants in the kernel arguments. 
+ iree_host_size_t num_constants = + iree_hal_rocm_pipeline_layout_num_constants(kernel_params.layout); + iree_host_size_t constant_base_index = + iree_hal_rocm_push_constant_index(kernel_params.layout); + for (iree_host_size_t i = 0; i < num_constants; i++) { + *((uint32_t*)command_buffer->current_descriptor[i + constant_base_index]) = + command_buffer->push_constant[i]; + } + + ROCM_RETURN_IF_ERROR( + command_buffer->context->syms, + hipModuleLaunchKernel(kernel_params.function, workgroup_x, workgroup_y, + workgroup_z, kernel_params.block_size[0], + kernel_params.block_size[1], kernel_params.block_size[2], + kernel_params.shared_memory_size, command_buffer->stream, + command_buffer->current_descriptor, NULL), + "hipModuleLaunchKernel"); + IREE_ROCM_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->stream); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_dispatch_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, + iree_device_size_t workgroups_offset) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "need rocm implementation of dispatch indirect"); +} + +static iree_status_t iree_hal_rocm_stream_command_buffer_execute_commands( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_command_buffer_t* base_commands, + iree_hal_buffer_binding_table_t binding_table) { + // TODO(#10144): support indirect command buffers with deferred command + // buffers or graphs. We likely just want to switch to graphs. 
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect command buffers not yet implemented"); +} + +static const iree_hal_command_buffer_vtable_t + iree_hal_rocm_stream_command_buffer_vtable = { + .destroy = iree_hal_rocm_stream_command_buffer_destroy, + .begin = iree_hal_rocm_stream_command_buffer_begin, + .end = iree_hal_rocm_stream_command_buffer_end, + .begin_debug_group = + iree_hal_rocm_stream_command_buffer_begin_debug_group, + .end_debug_group = iree_hal_rocm_stream_command_buffer_end_debug_group, + .execution_barrier = + iree_hal_rocm_stream_command_buffer_execution_barrier, + .signal_event = iree_hal_rocm_stream_command_buffer_signal_event, + .reset_event = iree_hal_rocm_stream_command_buffer_reset_event, + .wait_events = iree_hal_rocm_stream_command_buffer_wait_events, + .discard_buffer = iree_hal_rocm_stream_command_buffer_discard_buffer, + .fill_buffer = iree_hal_rocm_stream_command_buffer_fill_buffer, + .update_buffer = iree_hal_rocm_stream_command_buffer_update_buffer, + .copy_buffer = iree_hal_rocm_stream_command_buffer_copy_buffer, + .collective = iree_hal_rocm_stream_command_buffer_collective, + .push_constants = iree_hal_rocm_stream_command_buffer_push_constants, + .push_descriptor_set = + iree_hal_rocm_stream_command_buffer_push_descriptor_set, + .dispatch = iree_hal_rocm_stream_command_buffer_dispatch, + .dispatch_indirect = + iree_hal_rocm_stream_command_buffer_dispatch_indirect, + .execute_commands = + iree_hal_rocm_stream_command_buffer_execute_commands, +}; diff --git a/experimental/rocm/stream_command_buffer.h b/experimental/rocm/stream_command_buffer.h new file mode 100644 index 0000000000000..691fa63809ffa --- /dev/null +++ b/experimental/rocm/stream_command_buffer.h @@ -0,0 +1,49 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_DRIVERS_ROCM_STREAM_COMMAND_BUFFER_H_ +#define IREE_HAL_DRIVERS_ROCM_STREAM_COMMAND_BUFFER_H_ + +#include "iree/base/internal/arena.h" +#include "iree/hal/api.h" +#include "experimental/rocm/context_wrapper.h" +#include "experimental/rocm/rocm_headers.h" +#include "experimental/rocm/dynamic_symbols.h" +#include "experimental/rocm/tracing.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a ROCM stream command buffer that immediately issues commands against +// the given |stream|. Access to |stream| must be synchronized by the user. +// +// If |block_pool| is non-NULL then the stream command buffer will retain copies +// of input data until reset. If NULL then the caller must ensure the lifetime +// of input data outlives the command buffer. +// +// This command buffer is used to both replay deferred command buffers and +// perform inline execution. When replaying the scratch data required for things +// like buffer updates is retained by the source deferred command buffer and as +// such the |block_pool| and can be NULL to avoid a double copy. +iree_status_t iree_hal_rocm_stream_command_buffer_create( + iree_hal_device_t* device, iree_hal_rocm_context_wrapper_t* context, + iree_hal_rocm_tracing_context_t* tracing_context, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_host_size_t binding_capacity, hipStream_t stream, + iree_arena_block_pool_t* block_pool, + iree_hal_command_buffer_t** out_command_buffer); + +// Returns true if |command_buffer| is a ROCM stream-based command buffer. 
+bool iree_hal_rocm_stream_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_DRIVERS_ROCM_STREAM_COMMAND_BUFFER_H_ diff --git a/runtime/src/iree/schemas/rocm_executable_def.fbs b/runtime/src/iree/schemas/rocm_executable_def.fbs index 3735f4b56c474..a2a20fed1db49 100644 --- a/runtime/src/iree/schemas/rocm_executable_def.fbs +++ b/runtime/src/iree/schemas/rocm_executable_def.fbs @@ -26,6 +26,9 @@ table ExecutableDef { // block_sizes:[BlockSizeDef]; + // Size of dynamic shared memory. + shared_memory_size:[uint32]; + // HSACO string of the module. hsaco_image:string; } From ff129d9ca1388dcebdc8ee74693303e705f56a21 Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Mon, 18 Sep 2023 02:38:26 -0700 Subject: [PATCH 38/44] [ROCM] Add supports_concurrent_managed_access --- experimental/rocm/dynamic_symbol_tables.h | 1 + experimental/rocm/rocm_allocator.c | 120 ++++++++++++++++++---- experimental/rocm/rocm_allocator.h | 2 + experimental/rocm/rocm_device.c | 1 + 4 files changed, 106 insertions(+), 18 deletions(-) diff --git a/experimental/rocm/dynamic_symbol_tables.h b/experimental/rocm/dynamic_symbol_tables.h index 5cc1ea62a248e..678635b984563 100644 --- a/experimental/rocm/dynamic_symbol_tables.h +++ b/experimental/rocm/dynamic_symbol_tables.h @@ -35,6 +35,7 @@ RC_PFN_DECL(hipMemcpyAsync, void *, const void *, size_t, hipMemcpyKind, hipStream_t) RC_PFN_DECL(hipMalloc, void **, size_t) RC_PFN_DECL(hipMallocManaged, hipDeviceptr_t *, size_t, unsigned int) +RC_PFN_DECL(hipMemPrefetchAsync, const void *, size_t, int, hipStream_t) RC_PFN_DECL(hipFree, void *) RC_PFN_DECL(hipHostFree, void *) RC_PFN_DECL(hipMemAllocHost, void **, size_t, unsigned int) diff --git a/experimental/rocm/rocm_allocator.c b/experimental/rocm/rocm_allocator.c index 3c63c71ec1bd7..84dfb32f81c87 100644 --- a/experimental/rocm/rocm_allocator.c +++ b/experimental/rocm/rocm_allocator.c @@ -15,8 +15,11 @@ typedef 
struct iree_hal_rocm_allocator_t { iree_hal_resource_t resource; - iree_hal_device_t* base_device; iree_hal_rocm_context_wrapper_t* context; + hipDevice_t device; + hipStream_t stream; + + bool supports_concurrent_managed_access; IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;) } iree_hal_rocm_allocator_t; @@ -30,10 +33,30 @@ static iree_hal_rocm_allocator_t* iree_hal_rocm_allocator_cast( } iree_status_t iree_hal_rocm_allocator_create( - iree_hal_rocm_context_wrapper_t* context, + iree_hal_rocm_context_wrapper_t* context, hipDevice_t device, hipStream_t stream, iree_hal_allocator_t** out_allocator) { IREE_ASSERT_ARGUMENT(context); IREE_TRACE_ZONE_BEGIN(z0); + + // To support device-local + host-visible memory we need concurrent managed + // access indicating that the host and devices can concurrently access the + // device memory. If we don't have this feature then we fall back to forcing + // all device-local + host-visible memory into host-local + device-visible + // page-locked memory. The compiler tries to avoid this for high-traffic + // buffers except for readback staging buffers. + int supports_concurrent_managed_access = 0; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, ROCM_RESULT_TO_STATUS( + context->syms, + hipDeviceGetAttribute( + &supports_concurrent_managed_access, + hipDeviceAttributeConcurrentManagedAccess, device), + "hipDeviceGetAttribute")); + IREE_TRACE_ZONE_APPEND_TEXT( + z0, supports_concurrent_managed_access + ? 
"has CONCURRENT_MANAGED_ACCESS" + : "no CONCURRENT_MANAGED_ACCESS (expect slow accesses on " + "device-local + host-visible memory)"); iree_hal_rocm_allocator_t* allocator = NULL; iree_status_t status = iree_allocator_malloc( context->host_allocator, sizeof(*allocator), (void**)&allocator); @@ -41,6 +64,9 @@ iree_status_t iree_hal_rocm_allocator_create( iree_hal_resource_initialize(&iree_hal_rocm_allocator_vtable, &allocator->resource); allocator->context = context; + allocator->device = device; + allocator->stream = stream; + allocator->supports_concurrent_managed_access = supports_concurrent_managed_access !=0; *out_allocator = (iree_hal_allocator_t*)allocator; } @@ -87,24 +113,31 @@ static iree_status_t iree_hal_rocm_allocator_query_memory_heaps( iree_host_size_t capacity, iree_hal_allocator_memory_heap_t* IREE_RESTRICT heaps, iree_host_size_t* IREE_RESTRICT out_count) { - const iree_host_size_t count = 3; + iree_hal_rocm_allocator_t* allocator = + iree_hal_rocm_allocator_cast(base_allocator); + + // TODO(benvanik): check CU_DEVICE_ATTRIBUTE_INTEGRATED and return a unified + // set of heaps (likely still a cached and uncached, at minimum). + iree_host_size_t count = 3; + if (allocator->supports_concurrent_managed_access) { + ++count; // device-local | host-visible + } if (out_count) *out_count = count; if (capacity < count) { // NOTE: lightweight as this is hit in normal pre-sizing usage. return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); } - // NOTE: this is all a guess - someone who is familiar with rocm will want - // to refine this further. - // Don't think there's a query for these. // Max allocation size may be much smaller in certain memory types such as // page-locked memory and it'd be good to enforce that. 
const iree_device_size_t max_allocation_size = ~(iree_device_size_t)0; const iree_device_size_t min_alignment = 64; + int i = 0; + // Device-local memory (dispatch resources): - heaps[0] = (iree_hal_allocator_memory_heap_t){ + heaps[i++] = (iree_hal_allocator_memory_heap_t){ .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_DISPATCH, @@ -112,27 +145,46 @@ static iree_status_t iree_hal_rocm_allocator_query_memory_heaps( .min_alignment = min_alignment, }; + if (allocator->supports_concurrent_managed_access) { + // Device-local managed memory with host mapping support: + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_COHERENT, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, + .max_allocation_size = max_allocation_size, + .min_alignment = min_alignment, + }; + } + // Write-combined page-locked host-local memory (upload): - heaps[1] = (iree_hal_allocator_memory_heap_t){ - .type = - IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_COHERENT, - .allowed_usage = - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_COHERENT, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, .max_allocation_size = max_allocation_size, .min_alignment = min_alignment, }; // Cached page-locked host-local memory (download): - heaps[2] = (iree_hal_allocator_memory_heap_t){ - .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_COHERENT | 
IREE_HAL_MEMORY_TYPE_HOST_CACHED, - .allowed_usage = - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, .max_allocation_size = max_allocation_size, .min_alignment = min_alignment, }; + IREE_ASSERT(i == count); return iree_ok_status(); } @@ -141,22 +193,46 @@ iree_hal_rocm_allocator_query_buffer_compatibility( iree_hal_allocator_t* IREE_RESTRICT base_allocator, iree_hal_buffer_params_t* IREE_RESTRICT params, iree_device_size_t* IREE_RESTRICT allocation_size) { + iree_hal_rocm_allocator_t* allocator = + iree_hal_rocm_allocator_cast(base_allocator); + // All buffers can be allocated on the heap. iree_hal_buffer_compatibility_t compatibility = IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE; - if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { - compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; + // Buffers are importable in ROCM under most cases, though performance may + // vary wildly. We don't fully verify that the buffer parameters are + // self-consistent and just look at whether we can get a device pointer. + if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE; } // Buffers can only be used on the queue if they are device visible. if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; + } if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) { compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH; } } + // If concurrent managed access is not supported then make device-local + + // host-visible allocations fall back to host-local + device-visible + // page-locked memory. 
This will be significantly slower for the device to + // access but the compiler only uses this type for readback staging buffers + // and it's better to function than function fast. + if (!allocator->supports_concurrent_managed_access && + iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_LOW_PERFORMANCE; + params->type &= ~(IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE); + params->type |= + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE; + } + // We are now optimal. params->type &= ~IREE_HAL_MEMORY_TYPE_OPTIMAL; @@ -209,6 +285,14 @@ static iree_status_t iree_hal_rocm_allocator_allocate_buffer( status = ROCM_RESULT_TO_STATUS( allocator->context->syms, hipMallocManaged(&device_ptr, allocation_size, hipMemAttachGlobal)); + if (iree_status_is_ok(status) && + allocator->supports_concurrent_managed_access) { + // Prefetch the buffer on the GPU device. + status = ROCM_RESULT_TO_STATUS( + allocator->context->syms, + hipMemPrefetchAsync(device_ptr, allocation_size, allocator->device, + allocator->stream)); + } host_ptr = (void*)device_ptr; } else { // Device only. diff --git a/experimental/rocm/rocm_allocator.h b/experimental/rocm/rocm_allocator.h index a2a89eab2cdd4..c735e830b0137 100644 --- a/experimental/rocm/rocm_allocator.h +++ b/experimental/rocm/rocm_allocator.h @@ -19,6 +19,8 @@ extern "C" { // Create a ROCM allocator. 
iree_status_t iree_hal_rocm_allocator_create( iree_hal_rocm_context_wrapper_t* context, + hipDevice_t device, + hipStream_t stream, iree_hal_allocator_t** out_allocator); #ifdef __cplusplus diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c index ba52d2ada00bd..ad677d0c29efb 100644 --- a/experimental/rocm/rocm_device.c +++ b/experimental/rocm/rocm_device.c @@ -139,6 +139,7 @@ static iree_status_t iree_hal_rocm_device_create_internal( } if (iree_status_is_ok(status)) { status = iree_hal_rocm_allocator_create(&device->context_wrapper, + device->device, device->stream, &device->device_allocator); } if (iree_status_is_ok(status) && From e41602b02b1ea512ca4b8bb29e95802b6c1f1132 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 18 Sep 2023 16:10:07 -0400 Subject: [PATCH 39/44] Set preferred location to the device for HIP Managed Memory The semantics for specifying different kinds of advice is unclear so I set it in two stages. --- experimental/rocm/dynamic_symbol_tables.h | 1 + experimental/rocm/hip_headers.h | 11 +++++++++++ experimental/rocm/rocm_allocator.c | 10 ++++++++++ 3 files changed, 22 insertions(+) diff --git a/experimental/rocm/dynamic_symbol_tables.h b/experimental/rocm/dynamic_symbol_tables.h index 678635b984563..a4d1c6c21eb5b 100644 --- a/experimental/rocm/dynamic_symbol_tables.h +++ b/experimental/rocm/dynamic_symbol_tables.h @@ -24,6 +24,7 @@ RC_PFN_DECL(hipInit, unsigned int) RC_PFN_DECL(hipModuleLaunchKernel, hipFunction_t, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, hipStream_t, void **, void **) +RC_PFN_DECL(hipMemAdvise, const void *, size_t, int, int) RC_PFN_DECL(hipMemset, void *, int, size_t) RC_PFN_DECL(hipMemsetAsync, void *, int, size_t, hipStream_t) RC_PFN_DECL(hipMemsetD32Async, void *, int, size_t, hipStream_t) diff --git a/experimental/rocm/hip_headers.h b/experimental/rocm/hip_headers.h index 15193c17e9622..26c9d51ffe751 100644 --- 
a/experimental/rocm/hip_headers.h +++ b/experimental/rocm/hip_headers.h @@ -234,6 +234,17 @@ typedef enum HIPmemAttach_flags_enum { HIP_MEM_ATTACH_SINGLE = 0x4, } HIPmemAttach_flags; +typedef enum HIPmemoryAdvise_flags_enum { + hipMemAdviseSetReadMostly = 1, + hipMemAdviseUnsetReadMostly = 2, + hipMemAdviseSetPreferredLocation = 3, + hipMemAdviseUnsetPreferredLocation = 4, + hipMemAdviseSetAccessedBy = 5, + hipMemAdviseUnsetAccessedBy = 6, + hipMemAdviseSetCoarseGrain = 100, + hipMemAdviseUnsetCoarseGrain = 101, +} HIPmemoryAdvise_flags; + typedef enum HIPctx_flags_enum { hipDeviceScheduleAuto = 0x00, hipDeviceScheduleSpin = 0x01, diff --git a/experimental/rocm/rocm_allocator.c b/experimental/rocm/rocm_allocator.c index 84dfb32f81c87..dbd0ea4b99366 100644 --- a/experimental/rocm/rocm_allocator.c +++ b/experimental/rocm/rocm_allocator.c @@ -285,6 +285,16 @@ static iree_status_t iree_hal_rocm_allocator_allocate_buffer( status = ROCM_RESULT_TO_STATUS( allocator->context->syms, hipMallocManaged(&device_ptr, allocation_size, hipMemAttachGlobal)); + if (iree_status_is_ok(status)) { + status = ROCM_RESULT_TO_STATUS( + allocator->context->syms, + hipMemAdvise(device_ptr, allocation_size, + hipMemAdviseSetPreferredLocation, allocator->device)); + status = ROCM_RESULT_TO_STATUS( + allocator->context->syms, + hipMemAdvise(device_ptr, allocation_size, + hipMemAdviseSetCoarseGrain, allocator->device)); + } if (iree_status_is_ok(status) && allocator->supports_concurrent_managed_access) { // Prefetch the buffer on the GPU device. From 342940f51b8ca11266e9762537d56125f46d59f6 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 8 Jun 2023 14:50:09 -0700 Subject: [PATCH 40/44] Making execution region results queue-ordered allocas. We don't currently insert deallocas and don't track live ranges but that can come in the future as we support more control flow. 
For now this at least gets all of the common allocations within an invocation into the queue-ordered bucket so that we can do proper async execution and use native queue-ordered (e.g. stream-ordered allocations in CUDA) functionality. With this change the caching allocator is no longer needed for CUDA in almost all cases. --- .../HAL/Conversion/StreamToHAL/Patterns.cpp | 45 +++---- .../StreamToHAL/test/resource_ops.mlir | 32 ++--- .../MaterializeDispatchInstrumentation.cpp | 8 +- .../compiler/Dialect/Stream/IR/StreamOps.cpp | 126 ++++++++++++++++-- .../compiler/Dialect/Stream/IR/StreamOps.td | 50 ++++--- .../Dialect/Stream/IR/test/resource_ops.mlir | 8 +- .../Dialect/Stream/Transforms/BUILD.bazel | 1 - .../Dialect/Stream/Transforms/CMakeLists.txt | 1 - .../Stream/Transforms/PackAllocations.cpp | 119 ----------------- .../Stream/Transforms/PackConstants.cpp | 2 +- .../Dialect/Stream/Transforms/Passes.cpp | 3 - .../Dialect/Stream/Transforms/Passes.h | 1 - .../Dialect/Stream/Transforms/Passes.td | 8 -- .../Stream/Transforms/ScheduleAllocation.cpp | 23 ++-- .../Stream/Transforms/test/BUILD.bazel | 1 - .../Stream/Transforms/test/CMakeLists.txt | 1 - .../Transforms/test/pack_allocations.mlir | 38 ------ .../Transforms/test/schedule_allocation.mlir | 81 ++++++----- .../Conversion/StreamToHALInline/Patterns.cpp | 13 +- experimental/cuda2/cuda_buffer.c | 8 ++ .../src/iree/hal/drivers/cuda/cuda_buffer.c | 8 ++ 21 files changed, 262 insertions(+), 315 deletions(-) delete mode 100644 compiler/src/iree/compiler/Dialect/Stream/Transforms/PackAllocations.cpp delete mode 100644 compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_allocations.mlir diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 9265878ec6647..8d31a02822972 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ 
b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -332,26 +332,19 @@ struct ResourceAllocOpPattern lookupAllocatorAndQueueAffinityFor(allocOp, rewriter); auto bufferType = rewriter.getType(); - SmallVector results; - for (auto [resourceResult, storageSize] : - llvm::zip_equal(allocOp.getResults(), allocOp.getStorageSizes())) { - auto resourceType = - llvm::cast(resourceResult.getType()); - - auto memoryTypes = IREE::HAL::MemoryTypeBitfield::None; - auto bufferUsage = IREE::HAL::BufferUsageBitfield::None; - if (failed(deriveAllowedResourceBufferBits(allocOp.getLoc(), resourceType, - memoryTypes, bufferUsage))) { - return failure(); - } + auto resourceType = + cast(allocOp.getResult().getType()); - auto allocateOp = rewriter.create( - allocOp.getLoc(), bufferType, allocator, queueAffinity, memoryTypes, - bufferUsage, storageSize); - results.push_back(allocateOp.getResult()); + auto memoryTypes = IREE::HAL::MemoryTypeBitfield::None; + auto bufferUsage = IREE::HAL::BufferUsageBitfield::None; + if (failed(deriveAllowedResourceBufferBits(allocOp.getLoc(), resourceType, + memoryTypes, bufferUsage))) { + return failure(); } - rewriter.replaceOp(allocOp, results); + rewriter.replaceOpWithNewOp( + allocOp, bufferType, allocator, queueAffinity, memoryTypes, bufferUsage, + adaptor.getStorageSize()); return success(); } }; @@ -367,16 +360,14 @@ struct ResourceAllocaOpPattern lookupDeviceAndQueueAffinityFor(allocaOp, rewriter); auto bufferType = rewriter.getType(); - // Transient allocations are device-local. Copies are required to get their - // contents back on the host/another device. - auto memoryTypes = IREE::HAL::MemoryTypeBitfield::DeviceLocal; - - // TODO(benvanik): refine usage. - // We know by construction that transient buffers are not host visible and - // as such can only be used for device commands. We should be able to more - // closely limit to just dispatch or transfer though. 
- auto bufferUsage = IREE::HAL::BufferUsageBitfield::Transfer | - IREE::HAL::BufferUsageBitfield::DispatchStorage; + auto resourceType = + cast(allocaOp.getResult().getType()); + auto memoryTypes = IREE::HAL::MemoryTypeBitfield::None; + auto bufferUsage = IREE::HAL::BufferUsageBitfield::None; + if (failed(deriveAllowedResourceBufferBits(loc, resourceType, memoryTypes, + bufferUsage))) { + return failure(); + } // Gather wait/signal fence, which are optional. Value waitFence = diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir index e70f8b04263a5..88cb014dff42e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir @@ -1,25 +1,21 @@ // RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s // CHECK-LABEL: @resourceAlloc -func.func @resourceAlloc(%arg0: index, %arg1: index) -> (!stream.resource, !stream.resource) { +func.func @resourceAlloc(%arg0: index) -> !stream.resource { // CHECK: %[[RET0:.+]] = hal.allocator.allocate // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") // CHECK-SAME: : !hal.buffer{%arg0} - // CHECK-NEXT: %[[RET1:.+]] = hal.allocator.allocate - // CHECK-SAME: type("DeviceVisible|DeviceLocal") - // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") - // CHECK-SAME: : !hal.buffer{%arg1} - %0:2 = stream.resource.alloc uninitialized : !stream.resource{%arg0}, !stream.resource{%arg1} - // CHECK: return %[[RET0]], %[[RET1]] - return %0#0, %0#1 : !stream.resource, !stream.resource + %0 = stream.resource.alloc uninitialized : !stream.resource{%arg0} + // CHECK: return %[[RET0]] + return %0 : !stream.resource } // ----- // CHECK-LABEL: @resourceAlloca // CHECK-SAME: (%[[SIZE:.+]]: index) -func.func 
@resourceAlloca(%size: index) -> (!stream.resource, !stream.timepoint) { +func.func @resourceAlloca(%size: index) -> (!stream.resource, !stream.timepoint) { // CHECK: %[[WAIT_FENCE:.+]] = util.null : !hal.fence // CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create // CHECK: %[[RET0:.+]] = hal.device.queue.alloca @@ -30,16 +26,16 @@ func.func @resourceAlloca(%size: index) -> (!stream.resource, !stream.t // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") // CHECK-SAME: : !hal.buffer{%[[SIZE]]} - %0:2 = stream.resource.alloca uninitialized : !stream.resource{%size} => !stream.timepoint + %0:2 = stream.resource.alloca uninitialized : !stream.resource{%size} => !stream.timepoint // CHECK: return %[[RET0]], %[[SIGNAL_FENCE]] - return %0#0, %0#1 : !stream.resource, !stream.timepoint + return %0#0, %0#1 : !stream.resource, !stream.timepoint } // ----- // CHECK-LABEL: @resourceAllocaAwait // CHECK-SAME: (%[[SIZE:.+]]: index, %[[WAIT_FENCE:.+]]: !hal.fence) -func.func @resourceAllocaAwait(%size: index, %await_timepoint: !stream.timepoint) -> (!stream.resource, !stream.timepoint) { +func.func @resourceAllocaAwait(%size: index, %await_timepoint: !stream.timepoint) -> (!stream.resource, !stream.timepoint) { // CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create // CHECK: %[[RET0:.+]] = hal.device.queue.alloca // CHECK-SAME: affinity(%c-1 @@ -49,16 +45,16 @@ func.func @resourceAllocaAwait(%size: index, %await_timepoint: !stream.timepoint // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") // CHECK-SAME: : !hal.buffer{%[[SIZE]]} - %0:2 = stream.resource.alloca uninitialized await(%await_timepoint) => !stream.resource{%size} => !stream.timepoint + %0:2 = stream.resource.alloca uninitialized await(%await_timepoint) => !stream.resource{%size} => !stream.timepoint // CHECK: return %[[RET0]], %[[SIGNAL_FENCE]] - return %0#0, %0#1 : !stream.resource, !stream.timepoint + return %0#0, 
%0#1 : !stream.resource, !stream.timepoint } // ----- // CHECK-LABEL: @resourceDealloca // CHECK-SAME: (%[[SIZE:.+]]: index, %[[RESOURCE:.+]]: !hal.buffer) -func.func @resourceDealloca(%size: index, %resource: !stream.resource) -> !stream.timepoint { +func.func @resourceDealloca(%size: index, %resource: !stream.resource) -> !stream.timepoint { // CHECK: %[[WAIT_FENCE:.+]] = util.null : !hal.fence // CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create // CHECK: hal.device.queue.dealloca @@ -66,7 +62,7 @@ func.func @resourceDealloca(%size: index, %resource: !stream.resource) // CHECK-SAME: wait(%[[WAIT_FENCE]]) // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) // CHECK-SAME: buffer(%[[RESOURCE]] : !hal.buffer) - %0 = stream.resource.dealloca %resource : !stream.resource{%size} => !stream.timepoint + %0 = stream.resource.dealloca %resource : !stream.resource{%size} => !stream.timepoint // CHECK: return %[[SIGNAL_FENCE]] return %0 : !stream.timepoint } @@ -77,14 +73,14 @@ func.func @resourceDealloca(%size: index, %resource: !stream.resource) // CHECK-LABEL: @resourceDeallocaAwait // CHECK-SAME: (%[[SIZE:.+]]: index, %[[RESOURCE:.+]]: !hal.buffer, %[[WAIT_FENCE:.+]]: !hal.fence) -func.func @resourceDeallocaAwait(%size: index, %resource: !stream.resource, %await_timepoint: !stream.timepoint) -> !stream.timepoint { +func.func @resourceDeallocaAwait(%size: index, %resource: !stream.resource, %await_timepoint: !stream.timepoint) -> !stream.timepoint { // CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create // CHECK: hal.device.queue.dealloca // CHECK-SAME: affinity(%c-1 // CHECK-SAME: wait(%[[WAIT_FENCE]]) // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) // CHECK-SAME: buffer(%[[RESOURCE]] : !hal.buffer) - %0 = stream.resource.dealloca await(%await_timepoint) => %resource : !stream.resource{%size} => !stream.timepoint + %0 = stream.resource.dealloca await(%await_timepoint) => %resource : !stream.resource{%size} => !stream.timepoint // CHECK: return %[[SIGNAL_FENCE]] return %0 : !stream.timepoint } 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp index ec7dc3064412f..e3ce36f7cd9c7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp @@ -161,11 +161,9 @@ class MaterializeDispatchInstrumentationPass OpBuilder::atBlockBegin(initializerOp.addEntryBlock()); Value bufferSize = initializerBuilder.create(loc, bufferSizeAttr); - Value buffer = initializerBuilder - .create( - loc, globalOp.getType(), bufferSize, - /*uninitialized=*/true, /*affinity=*/nullptr) - .getResult(0); + Value buffer = initializerBuilder.create( + loc, globalOp.getType(), bufferSize, + /*uninitialized=*/true, /*affinity=*/nullptr); initializerBuilder.create(loc, buffer, globalOp); initializerBuilder.create(loc); diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp index f7d30afe49643..5b15968841f9c 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp @@ -615,22 +615,124 @@ static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op, // stream.resource.alloc //===----------------------------------------------------------------------===// -LogicalResult ResourceAllocOp::verify() { - ResourceAllocOp op = *this; - if (failed(verifyOpValueSizes(op, op.getResults(), op.getStorageSizes()))) { - return failure(); +// static +std::pair> +ResourceAllocOp::createSuballocations( + Type resourceType, ArrayRef locs, ValueRange storageSizes, + bool uninitialized, AffinityAttr affinityAttr, OpBuilder &builder) { + assert(locs.size() == storageSizes.size() && + "expect locs and storageSizes to match"); + if (locs.empty()) + return {}; + if (locs.size() == 1) { 
+ auto allocOp = builder.create( + locs.front(), resourceType, storageSizes.front(), uninitialized, + affinityAttr); + return {allocOp, {allocOp.getResult()}}; } + auto fusedLoc = builder.getFusedLoc(locs); - // All allocated resources must have the same lifetime. - auto anyType = op.getResults().front().getType(); - for (auto type : op.getResultTypes()) { - if (type != anyType) { - return op.emitError() - << "all allocated resources must have the same lifetime"; - } + // NOTE: this is risky: we are assuming right now that all of the + // allocations will fit within the constraints of the system. This is not + // guaranteed: a very low maximum buffer range may lead to packed slabs + // that are not fully addressable. For now we are processing models with + // small enough workloads and our target devices are relatively lax on + // things so long as we stay under UINT32_MAX boundaries. + + // All slices are 0-0 (overlapping). + size_t sliceCount = locs.size(); + SmallVector lifetimeIntervals(sliceCount * 2, 0); + + // Compute total size and the offsets of all suballocated resources via the + // pack op. + auto indexType = builder.getIndexType(); + SmallVector packedOffsetTypes(sliceCount, indexType); + auto packOp = builder.create( + fusedLoc, indexType, packedOffsetTypes, /*offset=*/nullptr, + builder.getIndexArrayAttr(lifetimeIntervals), storageSizes, affinityAttr); + + // Create the new alloca based on the total required size. + auto allocOp = builder.create( + fusedLoc, resourceType, packOp.getTotalLength(), uninitialized, + affinityAttr); + auto slab = allocOp.getResult(); + auto slabSize = packOp.getTotalLength(); + + // Create subviews for all of the suballocated resources. 
+ SmallVector results; + for (auto [loc, subviewOffset, subviewLength] : + llvm::zip_equal(locs, packOp.getPackedOffsets(), storageSizes)) { + results.push_back(builder + .create( + loc, slab, slabSize, subviewOffset, subviewLength) + .getResult()); } + return {allocOp, results}; +} - return success(); +//===----------------------------------------------------------------------===// +// stream.resource.alloca +//===----------------------------------------------------------------------===// + +// static +std::pair> +ResourceAllocaOp::createSuballocations(Type timepointType, Type resourceType, + ArrayRef locs, + ValueRange storageSizes, + Value awaitTimepoint, + AffinityAttr affinityAttr, + OpBuilder &builder) { + assert(locs.size() == storageSizes.size() && + "expect locs and storageSizes to match"); + if (locs.empty()) + return {}; + if (locs.size() == 1) { + auto allocaOp = builder.create( + locs.front(), resourceType, timepointType, storageSizes.front(), + awaitTimepoint, affinityAttr); + return {allocaOp, {allocaOp.getResult()}}; + } + auto fusedLoc = builder.getFusedLoc(locs); + + // NOTE: this is risky: we are assuming right now that all of the + // allocations will fit within the constraints of the system. This is not + // guaranteed: a very low maximum buffer range may lead to packed slabs + // that are not fully addressable. For now we are processing models with + // small enough workloads and our target devices are relatively lax on + // things so long as we stay under UINT32_MAX boundaries. If a user starts + // hitting this the solution is to do in-place outputs such that we don't + // need to allocate them; when possible that's always going to be better than + // leaving them to the IREE compiled program to deal with. + + // All slices are 0-0 (overlapping). + size_t sliceCount = locs.size(); + SmallVector lifetimeIntervals(sliceCount * 2, 0); + + // Compute total size and the offsets of all suballocated resources via the + // pack op. 
+ auto indexType = builder.getIndexType(); + SmallVector packedOffsetTypes(sliceCount, indexType); + auto packOp = builder.create( + fusedLoc, indexType, packedOffsetTypes, /*offset=*/nullptr, + builder.getIndexArrayAttr(lifetimeIntervals), storageSizes, affinityAttr); + + // Create the new alloca based on the total required size. + auto allocaOp = builder.create( + fusedLoc, resourceType, timepointType, packOp.getTotalLength(), + awaitTimepoint, affinityAttr); + auto slab = allocaOp.getResult(); + auto slabSize = packOp.getTotalLength(); + + // Create subviews for all of the suballocated resources. + SmallVector results; + for (auto [loc, subviewOffset, subviewLength] : + llvm::zip_equal(locs, packOp.getPackedOffsets(), storageSizes)) { + results.push_back(builder + .create( + loc, slab, slabSize, subviewOffset, subviewLength) + .getResult()); + } + return {allocaOp, results}; } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index d49ad4a02c4b4..9396301107a3a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td @@ -42,7 +42,7 @@ def Stream_ResourceAllocOp : Stream_Op<"resource.alloc", [ AlwaysSpeculatable, MemoryEffects<[MemAlloc]>, ]> { - let summary = [{allocates a persistent value with undefined contents}]; + let summary = [{allocates a persistent resource}]; let description = [{ Allocates a persistent value (one that is long-lived and possibly external to the program) with undefined contents. Consumers of the allocated @@ -58,31 +58,39 @@ def Stream_ResourceAllocOp : Stream_Op<"resource.alloc", [ separate allocations may be fused into one or more slab allocations in order to reduce overheads. 
How many allocations can be fused is based on the size of the individual resources and the target constraints (how large any single - buffer may be, etc). At the stream dialect level treat a multi-result alloc - as a way to indicate similar lifetimes. + buffer may be, etc). }]; let arguments = (ins - Variadic:$storage_sizes, + Stream_Size:$storage_size, UnitAttr:$uninitialized, OptionalAttr:$affinity ); let results = (outs - Variadic:$results + Stream_AnyResource:$result ); let assemblyFormat = [{ (`on` `(` $affinity^ `)`)? (`uninitialized` $uninitialized^)? - attr-dict `:` custom(type($results), $storage_sizes) + attr-dict `:` + type($result) `{` $storage_size `}` }]; let extraClassDeclaration = [{ Value getOperandSize(unsigned idx) { return {}; } - Value getResultSize(unsigned idx) { return getStorageSizes()[idx]; } - }]; + Value getResultSize(unsigned idx) { return getStorageSize(); } - let hasVerifier = 1; + // Creates a single shared allocation for multiple suballocations. + // Suballocations are defined by entries in the struct-of-arrays-style + // `{locs, storageSizes}` set. Currently all result types must match. + // Returns the allocation and subviews into all suballocated resources. + static std::pair> + createSuballocations( + Type resourceType, + ArrayRef locs, ValueRange storageSizes, + bool uninitialized, AffinityAttr affinityAttr, OpBuilder &builder); + }]; let hasCanonicalizer = 1; } @@ -113,10 +121,7 @@ def Stream_ResourceAllocaOp : Stream_Op<"resource.alloca", [ OptionalAttr:$affinity ); let results = (outs - AnyTypeOf<[ - Stream_StagingResource, - Stream_TransientResource, - ]>:$result, + Stream_AnyResource:$result, Stream_Timepoint:$result_timepoint ); @@ -136,6 +141,16 @@ def Stream_ResourceAllocaOp : Stream_Op<"resource.alloca", [ SmallVector getAwaitTimepoints() { if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {}; } + + // Creates a single shared allocation for multiple suballocations. 
+ // Suballocations are defined by entries in the struct-of-arrays-style + // `{locs, storageSizes}` set. Currently all result types must match. + // Returns the allocation and subviews into all suballocated resources. + static std::pair> + createSuballocations( + Type timepointType, Type resourceType, + ArrayRef locs, ValueRange storageSizes, + Value awaitTimepoint, AffinityAttr affinityAttr, OpBuilder &builder); }]; let hasCanonicalizer = 1; @@ -161,10 +176,7 @@ def Stream_ResourceDeallocaOp : Stream_Op<"resource.dealloca", [ }]; let arguments = (ins - AnyTypeOf<[ - Stream_StagingResource, - Stream_TransientResource, - ]>:$operand, + Stream_AnyResource:$operand, Stream_Size:$operand_size, Optional:$await_timepoint, OptionalAttr:$affinity @@ -776,7 +788,7 @@ def Stream_TensorImportOp : Stream_PureOp<"tensor.import", [ OptionalAttr:$affinity ); let results = (outs - Stream_ExternalResource:$result + AnyTypeOf<[Stream_AnyStreamResource, Stream_StagingResource]>:$result ); let assemblyFormat = [{ @@ -822,7 +834,7 @@ def Stream_TensorExportOp : Stream_PureOp<"tensor.export", [ }]; let arguments = (ins - Stream_ExternalResource:$source, + AnyTypeOf<[Stream_AnyStreamResource, Stream_StagingResource]>:$source, TypeAttr:$source_encoding, Stream_ShapeDynamicDims:$source_encoding_dims, Stream_Size:$source_size, diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_ops.mlir index c42ccfee0a14c..f19f53007d03d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_ops.mlir @@ -1,10 +1,10 @@ // RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s // CHECK-LABEL: @resourceAlloc -func.func @resourceAlloc(%arg0: index, %arg1: index) -> (!stream.resource<*>, !stream.resource<*>) { - // CHECK: = stream.resource.alloc uninitialized : !stream.resource<*>{%arg0}, 
!stream.resource<*>{%arg1} - %0:2 = stream.resource.alloc uninitialized : !stream.resource<*>{%arg0}, !stream.resource<*>{%arg1} - return %0#0, %0#1 : !stream.resource<*>, !stream.resource<*> +func.func @resourceAlloc(%arg0: index) -> !stream.resource<*> { + // CHECK: = stream.resource.alloc uninitialized : !stream.resource<*>{%arg0} + %0 = stream.resource.alloc uninitialized : !stream.resource<*>{%arg0} + return %0 : !stream.resource<*> } // ----- diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel index f3693a8868cf8..9144ae498e37a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel @@ -27,7 +27,6 @@ iree_compiler_cc_library( "LayoutSlices.cpp", "MaterializeBuiltins.cpp", "MaterializeCopyOnWrite.cpp", - "PackAllocations.cpp", "PackConstants.cpp", "PackDispatchOperands.cpp", "PassDetail.h", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt index 11074572de613..80dec36d1ff71 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt @@ -29,7 +29,6 @@ iree_cc_library( "LayoutSlices.cpp" "MaterializeBuiltins.cpp" "MaterializeCopyOnWrite.cpp" - "PackAllocations.cpp" "PackConstants.cpp" "PackDispatchOperands.cpp" "PassDetail.h" diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackAllocations.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackAllocations.cpp deleted file mode 100644 index 4e4ed6246d281..0000000000000 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackAllocations.cpp +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" -#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" -#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" -#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h" -#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" -#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" -#include "iree/compiler/Dialect/Util/IR/UtilOps.h" -#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" -#include "llvm/Support/Debug.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/AsmState.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Pass/Pass.h" - -#define DEBUG_TYPE "iree-stream-pack-allocations" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace Stream { -namespace { - -//===----------------------------------------------------------------------===// -// -iree-stream-pack-allocations -//===----------------------------------------------------------------------===// - -class PackAllocationsPass : public PackAllocationsBase { -public: - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - registry.insert(); - registry.insert(); - } - - void runOnOperation() override { - auto parentOp = getOperation(); - if (!parentOp.getCallableRegion() || - parentOp.getCallableRegion()->empty()) { - return; - } - - // This is pretty lazy: we just turn stream.resource.alloc ops into a - // stream.resource.pack + stream.resource.alloc of a single resource. - // This way we reuse all the resource constraints stuff that the pack op - // provides even though all of the resources we allocate have perfectly - // overlapping lifetime spans. 
- // - // In the future, we should be doing deeper lifetime analysis here and - // subdividing the allocs based on which resources travel together. We can - // also do things like overlap the lifetime of inputs and outputs to - // execution regions as usually inputs end their lifetime before the outputs - // are produced. In this way we'd use the slice intervals to denote which - // are mutually exclusive. - parentOp.walk([&](IREE::Stream::ResourceAllocOp allocOp) { - // If just one result then ignore (nothing to pack). - if (allocOp.getResults().size() == 1) - return; - auto resourceType = allocOp.getResults().front().getType(); - - // NOTE: this is risky: we are assuming right now that all of the - // allocations will fit within the constraints of the system. This is not - // guaranteed: a very low maximum buffer range may lead to packed slabs - // that are not fully addressable. For now we are processing models with - // small enough workloads and our target devices are relatively lax on - // things so long as we stay under UINT32_MAX boundaries. - - // All slices are 0-0 (overlapping). - size_t sliceCount = allocOp.getResults().size(); - SmallVector lifetimeIntervals(sliceCount * 2, 0); - - OpBuilder builder(allocOp); - auto indexType = builder.getIndexType(); - SmallVector packedOffsetTypes(sliceCount, indexType); - auto packOp = builder.create( - allocOp.getLoc(), indexType, packedOffsetTypes, /*offset=*/nullptr, - builder.getIndexArrayAttr(lifetimeIntervals), - allocOp.getStorageSizes(), allocOp.getAffinityAttr()); - - // Change the alloc to build just a single resource. - auto newOp = builder.create( - allocOp.getLoc(), resourceType, packOp.getTotalLength(), - allocOp.getUninitializedAttr(), allocOp.getAffinityAttr()); - auto slab = newOp.getResults().front(); - auto slabSize = packOp.getTotalLength(); - - // Replace all resources with subviews into the new slab. 
- for (auto [originalValue, subviewOffset, subviewLength] : - llvm::zip_equal(allocOp.getResults(), packOp.getPackedOffsets(), - allocOp.getStorageSizes())) { - auto subviewOp = builder.create( - allocOp.getLoc(), slab, slabSize, subviewOffset, subviewLength); - originalValue.replaceAllUsesWith(subviewOp.getResult()); - } - - allocOp.erase(); - }); - } -}; - -} // namespace - -std::unique_ptr> -createPackAllocationsPass() { - return std::make_unique(); -} - -} // namespace Stream -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp index 74ad0dd93266d..87c02fa8ef450 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp @@ -262,7 +262,7 @@ static TimepointResource buildFileRead( auto zeroI64 = builder.create(storageResource.loc, 0, 64); auto readOp = builder.create( - storageResource.loc, fileOp.getResult(), zeroI64, allocOp.getResult(0), + storageResource.loc, fileOp.getResult(), zeroI64, allocOp.getResult(), allocOp.getResultSize(0), indexSet.get(0), storageResourceSize, awaitTimepoint, affinityAttr); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index 797ac47732a56..fbfa6631ec374 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -206,9 +206,6 @@ void buildStreamCmdPassPipeline(OpPassManager &passManager, // storage buffers and upload logic. .addPass(IREE::Stream::createPackConstantsPass) - // Pack fused allocations based on lifetime. - .addPass(IREE::Stream::createPackAllocationsPass) - // Layout packed slices to emit the arithmetic required for all resource // offsets. 
This enables us to propagate the subviews across the program // below. diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h index 324937fd8d549..ca8c861f85eae 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h @@ -128,7 +128,6 @@ std::unique_ptr> createElideTimepointsPass(); std::unique_ptr> createScheduleAllocationPass(); std::unique_ptr> createPackConstantsPass(); -std::unique_ptr> createPackAllocationsPass(); std::unique_ptr> createLayoutSlicesPass(); //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td index c11e3959e6aa2..1a19d4136e928 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td @@ -137,14 +137,6 @@ def PackConstants : }]; } -def PackAllocations : - InterfacePass<"iree-stream-pack-allocations", "mlir::CallableOpInterface"> { - let summary = "Packs fused allocations based on lifetime."; - let constructor = [{ - mlir::iree_compiler::IREE::Stream::createPackAllocationsPass() - }]; -} - def LayoutSlices : InterfacePass<"iree-stream-layout-slices", "mlir::CallableOpInterface"> { let summary = "Lays out packed slices and produces arithmetic required for all offsets."; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp index 59c104ae60fa6..e2f3e976e06c8 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp @@ -1581,25 +1581,28 @@ allocateExecutionRegion(IREE::Stream::AsyncExecuteOp executeOp) { // 
TODO(benvanik): change this to an alloca. We may need a higher-level // analysis to decide when to deallocate, or just leave it to be deallocated // as part of garbage collection. - auto allocOp = externalBuilder.create( - externalBuilder.getFusedLoc(reservationSet.reservationLocs), - reservationSet.reservationTypes, reservationSet.reservationSizes, - /*uninitialized=*/externalBuilder.getUnitAttr(), - executeOp.getAffinityAttr()); + auto timepointType = externalBuilder.getType(); + auto [allocaOp, suballocations] = + IREE::Stream::ResourceAllocaOp::createSuballocations( + timepointType, reservationSet.reservationTypes.front(), + reservationSet.reservationLocs, reservationSet.reservationSizes, + executeOp.getAwaitTimepoint(), executeOp.getAffinityAttr(), + externalBuilder); + newAwaitTimepoints.push_back(allocaOp.getResultTimepoint()); auto asmState = getRootAsmState(executeOp->getParentOp()); LLVM_DEBUG({ llvm::dbgs() << " + alloc for result reservation set: "; - allocOp.print(llvm::dbgs(), *asmState); + allocaOp.print(llvm::dbgs(), *asmState); llvm::dbgs() << ":\n"; }); - for (auto [reservation, allocResult] : - llvm::zip_equal(reservationSet.reservations, allocOp.getResults())) { - newOperands.push_back(allocResult); + for (auto [reservation, suballocation] : + llvm::zip_equal(reservationSet.reservations, suballocations)) { + newOperands.push_back(suballocation); newOperandSizes.push_back(reservation.resultSize); resultReplacements.push_back( - std::make_pair(reservation.result, allocResult)); + std::make_pair(reservation.result, suballocation)); // Insert entry arg for the new operand tied all the way to the yield. 
auto arg = diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel index d0a3103a3a405..965b61f2cc632 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel @@ -34,7 +34,6 @@ iree_lit_test_suite( "layout_slices.mlir", "materialize_builtins.mlir", "materialize_copy_on_write.mlir", - "pack_allocations.mlir", "pack_constants.mlir", "pack_dispatch_operands.mlir", "propagate_subviews.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt index 2a1d56a054321..2e2294a000545 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt @@ -32,7 +32,6 @@ iree_lit_test_suite( "layout_slices.mlir" "materialize_builtins.mlir" "materialize_copy_on_write.mlir" - "pack_allocations.mlir" "pack_constants.mlir" "pack_dispatch_operands.mlir" "propagate_subviews.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_allocations.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_allocations.mlir deleted file mode 100644 index e98e5ae9533fa..0000000000000 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_allocations.mlir +++ /dev/null @@ -1,38 +0,0 @@ -// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-stream-pack-allocations))' %s | FileCheck %s - -// CHECK-LABEL: @packAllocations -// CHECK-SAME: (%[[SIZE_A:.+]]: index, %[[SIZE_B:.+]]: index) -func.func @packAllocations(%size_a: index, %size_b: index) { - // CHECK: %[[SLICES:.+]]:3 = stream.resource.pack slices({ - // CHECK-NEXT: [0, 0] = %[[SIZE_A]], - // CHECK-NEXT: [0, 0] = %[[SIZE_B]] - // CHECK-NEXT: }) : index - // CHECK: 
%[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource{%[[SLICES]]#0} - %0:2 = stream.resource.alloc uninitialized : - !stream.resource{%size_a}, - !stream.resource{%size_b} - - // CHECK: %[[SLICE_A:.+]] = stream.resource.subview %[[ALLOC]][%[[SLICES]]#1] - // CHECK-SAME: !stream.resource{%[[SLICES]]#0} -> !stream.resource{%[[SIZE_A]]} - // CHECK: %[[SLICE_B:.+]] = stream.resource.subview %[[ALLOC]][%[[SLICES]]#2] - // CHECK-SAME: !stream.resource{%[[SLICES]]#0} -> !stream.resource{%[[SIZE_B]]} - - // CHECK: util.optimization_barrier %[[SLICE_A]] - util.optimization_barrier %0#0 : !stream.resource - // CHECK: util.optimization_barrier %[[SLICE_B]] - util.optimization_barrier %0#1 : !stream.resource - return -} - -// ----- - -// CHECK-LABEL: @packEmpty -func.func @packEmpty() { - // CHECK: %[[ALLOC:.+]] = stream.resource.alloc : !stream.resource{%c0} - %c0 = arith.constant 0 : index - %0 = stream.resource.alloc : !stream.resource{%c0} - - // CHECK: util.optimization_barrier %[[ALLOC]] - util.optimization_barrier %0 : !stream.resource - return -} diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir index 08477f30547b6..1398f799813df 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir @@ -183,10 +183,16 @@ func.func @aliasPropagation(%operand: !stream.resource, %size: index, func.func @producedResults(%size0: index, %size1: index) { %c254_i32 = arith.constant 254 : i32 %c255_i32 = arith.constant 255 : i32 - // CHECK: %[[ALLOC_RETS:.+]]:2 = stream.resource.alloc uninitialized : !stream.resource{%[[SIZE0]]}, !stream.resource{%[[SIZE1]]} - // CHECK: %[[TIMEPOINT:.+]] = stream.cmd.execute - // CHECK-SAME: with(%[[ALLOC_RETS]]#0 as %[[CAPTURE0:.+]]: !stream.resource{%[[SIZE0]]}, - // CHECK-SAME: 
%[[ALLOC_RETS]]#1 as %[[CAPTURE1:.+]]: !stream.resource{%[[SIZE1]]}) + // CHECK: %[[PACK:.+]]:3 = stream.resource.pack slices({ + // CHECK-NEXT: [0, 0] = %[[SIZE0]], + // CHECK-NEXT: [0, 0] = %[[SIZE1]] + // CHECK-NEXT: }) : index + // CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized : !stream.resource{%[[PACK]]#0} + // CHECK: %[[SUBALLOCA0:.+]] = stream.resource.subview %[[ALLOCA]][%[[PACK]]#1] : !stream.resource{%[[PACK]]#0} -> !stream.resource{%[[SIZE0]]} + // CHECK: %[[SUBALLOCA1:.+]] = stream.resource.subview %[[ALLOCA]][%[[PACK]]#2] : !stream.resource{%[[PACK]]#0} -> !stream.resource{%[[SIZE1]]} + // CHECK: %[[EXECUTE_TIMEPOINT:.+]] = stream.cmd.execute await(%[[ALLOCA_TIMEPOINT]]) + // CHECK-SAME: with(%[[SUBALLOCA0]] as %[[CAPTURE0:.+]]: !stream.resource{%[[SIZE0]]}, + // CHECK-SAME: %[[SUBALLOCA1]] as %[[CAPTURE1:.+]]: !stream.resource{%[[SIZE1]]}) %results:2, %result_timepoint = stream.async.execute with() -> (!stream.resource{%size0}, !stream.resource{%size1}) { // CHECK: stream.cmd.fill %c254_i32, %[[CAPTURE0]] %0 = stream.async.splat %c254_i32 : i32 -> !stream.resource{%size0} @@ -194,11 +200,11 @@ func.func @producedResults(%size0: index, %size1: index) { %1 = stream.async.splat %c255_i32 : i32 -> !stream.resource{%size1} stream.yield %0, %1 : !stream.resource{%size0}, !stream.resource{%size1} } => !stream.timepoint - // CHECK: util.optimization_barrier %[[TIMEPOINT]] + // CHECK: util.optimization_barrier %[[EXECUTE_TIMEPOINT]] util.optimization_barrier %result_timepoint : !stream.timepoint - // CHECK: util.optimization_barrier %[[ALLOC_RETS]]#0 + // CHECK: util.optimization_barrier %[[SUBALLOCA0]] util.optimization_barrier %results#0 : !stream.resource - // CHECK: util.optimization_barrier %[[ALLOC_RETS]]#1 + // CHECK: util.optimization_barrier %[[SUBALLOCA1]] util.optimization_barrier %results#1 : !stream.resource return } @@ -248,10 +254,10 @@ func.func @concurrentRegions(%operand: !stream.resource, %size: 
index %c128 = arith.constant 128 : index %c254_i32 = arith.constant 254 : i32 %c255_i32 = arith.constant 255 : i32 - // CHECK: %[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource{%[[SIZE]]} - // CHECK: stream.cmd.execute + // CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized : !stream.resource{%[[SIZE]]} + // CHECK: stream.cmd.execute await(%[[ALLOCA_TIMEPOINT]]) // CHECK-SAME: with(%[[OPERAND]] as %[[OPERAND_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}, - // CHECK-SAME: %[[ALLOC]] as %[[ALLOC_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}) + // CHECK-SAME: %[[ALLOCA]] as %[[ALLOC_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}) %results:2, %result_timepoint = stream.async.execute with(%operand as %capture: !stream.resource{%size}) -> (!stream.resource{%size}, !stream.resource{%size}) { // CHECK: stream.cmd.concurrent %0:2 = stream.async.concurrent with(%capture as %concurrent_capture: !stream.resource{%size}) -> (%capture as !stream.resource{%size}, !stream.resource{%size}) { @@ -265,7 +271,7 @@ func.func @concurrentRegions(%operand: !stream.resource, %size: index } => !stream.timepoint // CHECK: util.optimization_barrier %[[OPERAND]] util.optimization_barrier %results#0 : !stream.resource - // CHECK: util.optimization_barrier %[[ALLOC]] + // CHECK: util.optimization_barrier %[[ALLOCA]] util.optimization_barrier %results#1 : !stream.resource return } @@ -276,14 +282,15 @@ func.func @concurrentRegions(%operand: !stream.resource, %size: index // CHECK-SAME: (%[[SIZE:.+]]: index) func.func @applyAsyncSplatOp(%size: index) { %c255_i32 = arith.constant 255 : i32 - // CHECK: %[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource{%[[SIZE]]} - // CHECK: stream.cmd.execute with(%[[ALLOC]] as %[[CAPTURE:.+]]: !stream.resource{%[[SIZE]]}) + // CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized : !stream.resource{%[[SIZE]]} + // CHECK: stream.cmd.execute await(%[[ALLOCA_TIMEPOINT]]) 
+ // CHECK-SAME: with(%[[ALLOCA]] as %[[CAPTURE:.+]]: !stream.resource{%[[SIZE]]}) %result, %result_timepoint = stream.async.execute with() -> (!stream.resource{%size}) { // CHECK: stream.cmd.fill %c255_i32, %[[CAPTURE]][%c0 for %[[SIZE]]] : i32 -> !stream.resource{%[[SIZE]]} %0 = stream.async.splat %c255_i32 : i32 -> !stream.resource{%size} stream.yield %0 : !stream.resource{%size} } => !stream.timepoint - // CHECK: util.optimization_barrier %[[ALLOC]] + // CHECK: util.optimization_barrier %[[ALLOCA]] util.optimization_barrier %result : !stream.resource return } @@ -293,17 +300,17 @@ func.func @applyAsyncSplatOp(%size: index) { // CHECK-LABEL: @applyAsyncCloneOp // CHECK-SAME: (%[[OPERAND:.+]]: !stream.resource, %[[SIZE:.+]]: index) func.func @applyAsyncCloneOp(%operand: !stream.resource, %size: index) { - // CHECK: %[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource{%[[SIZE]]} - // CHECK: stream.cmd.execute + // CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized : !stream.resource{%[[SIZE]]} + // CHECK: stream.cmd.execute await(%[[ALLOCA_TIMEPOINT]]) // CHECK-SAME: with(%[[OPERAND]] as %[[OPERAND_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}, - // CHECK-SAME: %[[ALLOC]] as %[[ALLOC_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}) + // CHECK-SAME: %[[ALLOCA]] as %[[ALLOC_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}) %result, %result_timepoint = stream.async.execute with(%operand as %capture: !stream.resource{%size}) -> !stream.resource{%size} { // CHECK: stream.cmd.copy %[[OPERAND_CAPTURE]][%c0], %[[ALLOC_CAPTURE]][%c0], %[[SIZE]] // CHECK-SAME: : !stream.resource{%[[SIZE]]} -> !stream.resource{%[[SIZE]]} %0 = stream.async.clone %capture : !stream.resource{%size} -> !stream.resource{%size} stream.yield %0 : !stream.resource{%size} } => !stream.timepoint - // CHECK: util.optimization_barrier %[[ALLOC]] + // CHECK: util.optimization_barrier %[[ALLOCA]] util.optimization_barrier %result : !stream.resource return } @@ -319,17 
+326,17 @@ func.func @applyAsyncSliceOp(%operand: !stream.resource, %size: index %c16 = arith.constant 16 : index %c128 = arith.constant 128 : index %c144 = arith.constant 144 : index - // CHECK: %[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource{%c128} - // CHECK: stream.cmd.execute + // CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized : !stream.resource{%c128} + // CHECK: stream.cmd.execute await(%[[ALLOCA_TIMEPOINT]]) // CHECK-SAME: with(%[[OPERAND]] as %[[OPERAND_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}, - // CHECK-SAME: %[[ALLOC]] as %[[ALLOC_CAPTURE:.+]]: !stream.resource{%c128}) + // CHECK-SAME: %[[ALLOCA]] as %[[ALLOC_CAPTURE:.+]]: !stream.resource{%c128}) %result, %result_timepoint = stream.async.execute with(%operand as %capture: !stream.resource{%size}) -> !stream.resource{%c128} { // CHECK: stream.cmd.copy %[[OPERAND_CAPTURE]][%c16], %[[ALLOC_CAPTURE]][%c0], %c128 // CHECK-SAME: : !stream.resource{%[[SIZE]]} -> !stream.resource{%c128} %0 = stream.async.slice %capture[%c16 to %c144] : !stream.resource{%size} -> !stream.resource{%c128} stream.yield %0 : !stream.resource{%c128} } => !stream.timepoint - // CHECK: util.optimization_barrier %[[ALLOC]] + // CHECK: util.optimization_barrier %[[ALLOCA]] util.optimization_barrier %result : !stream.resource return } @@ -477,17 +484,17 @@ func.func @applyAsyncCollectiveOpOutOfPlace(%channel: !stream.channel, %send: !s // CHECK-LABEL: @applyAsyncTransferOp // CHECK-SAME: (%[[OPERAND:.+]]: !stream.resource, %[[SIZE:.+]]: index) func.func @applyAsyncTransferOp(%operand: !stream.resource, %size: index) { - // CHECK: %[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource{%[[SIZE]]} - // CHECK: stream.cmd.execute + // CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized : !stream.resource{%[[SIZE]]} + // CHECK: stream.cmd.execute await(%[[ALLOCA_TIMEPOINT]]) // CHECK-SAME: with(%[[OPERAND]] as 
%[[OPERAND_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}, - // CHECK-SAME: %[[ALLOC]] as %[[ALLOC_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}) + // CHECK-SAME: %[[ALLOCA]] as %[[ALLOCA_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}) %result, %result_timepoint = stream.async.execute with(%operand as %capture: !stream.resource{%size}) -> !stream.resource{%size} { - // CHECK: stream.cmd.copy %[[OPERAND_CAPTURE]][%c0], %[[ALLOC_CAPTURE]][%c0], %[[SIZE]] + // CHECK: stream.cmd.copy %[[OPERAND_CAPTURE]][%c0], %[[ALLOCA_CAPTURE]][%c0], %[[SIZE]] // CHECK-SAME: : !stream.resource{%[[SIZE]]} -> !stream.resource{%[[SIZE]]} %0 = stream.async.transfer %capture : !stream.resource{%size} -> !stream.resource{%size} stream.yield %0 : !stream.resource{%size} } => !stream.timepoint - // CHECK: util.optimization_barrier %[[ALLOC]] + // CHECK: util.optimization_barrier %[[ALLOCA]] util.optimization_barrier %result : !stream.resource return } @@ -500,14 +507,14 @@ func.func @applyAsyncDispatchOp(%operand: !stream.resource, %size: in %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index - // CHECK: %[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource{%[[SIZE]]} - // CHECK: %[[TIMEPOINT:.+]] = stream.cmd.execute + // CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized : !stream.resource{%[[SIZE]]} + // CHECK: %[[TIMEPOINT:.+]] = stream.cmd.execute await(%[[ALLOCA_TIMEPOINT]]) // CHECK-SAME: with(%[[OPERAND]] as %[[OPERAND_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}, - // CHECK-SAME: %[[ALLOC]] as %[[ALLOC_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}) + // CHECK-SAME: %[[ALLOCA]] as %[[ALLOCA_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}) %results:2, %result_timepoint = stream.async.execute with(%operand as %capture: !stream.resource{%size}) -> (%operand as !stream.resource{%size}, !stream.resource{%size}) { // CHECK-NEXT: stream.cmd.dispatch @executable::@dispatch[%c1, %c1, %c1](%c4 : index) { // CHECK-NEXT: rw 
%[[OPERAND_CAPTURE]][%[[OFFSET]] for %[[LENGTH]]] : !stream.resource{%[[SIZE]]}, - // CHECK-NEXT: wo %[[ALLOC_CAPTURE]][%c0{{[_0-9]*}} for %[[SIZE]]] : !stream.resource{%[[SIZE]]} + // CHECK-NEXT: wo %[[ALLOCA_CAPTURE]][%c0{{[_0-9]*}} for %[[SIZE]]] : !stream.resource{%[[SIZE]]} // CHECK-NEXT: } %0:2 = stream.async.dispatch @executable::@dispatch[%c1, %c1, %c1](%capture[%offset to %end for %length], %c4) : (!stream.resource{%size}, index) -> (%capture{%size}, !stream.resource{%size}) stream.yield %0#0, %0#1 : !stream.resource{%size}, !stream.resource{%size} @@ -516,7 +523,7 @@ func.func @applyAsyncDispatchOp(%operand: !stream.resource, %size: in util.optimization_barrier %result_timepoint : !stream.timepoint // CHECK: util.optimization_barrier %[[OPERAND]] util.optimization_barrier %results#0 : !stream.resource - // CHECK: util.optimization_barrier %[[ALLOC]] + // CHECK: util.optimization_barrier %[[ALLOCA]] util.optimization_barrier %results#1 : !stream.resource return } @@ -571,12 +578,12 @@ func.func @applyAsyncCallOp(%operand: !stream.resource, %size: index, %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index - // CHECK: %[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource{%[[SIZE]]} - // CHECK: %[[TIMEPOINT:.+]] = stream.cmd.execute + // CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized : !stream.resource{%[[SIZE]]} + // CHECK: %[[TIMEPOINT:.+]] = stream.cmd.execute await(%[[ALLOCA_TIMEPOINT]]) // CHECK-SAME: with(%[[OPERAND]] as %[[OPERAND_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}, - // CHECK-SAME: %[[ALLOC]] as %[[ALLOC_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}) + // CHECK-SAME: %[[ALLOCA]] as %[[ALLOCA_CAPTURE:.+]]: !stream.resource{%[[SIZE]]}) %results:2, %result_timepoint = stream.async.execute with(%operand as %capture: !stream.resource{%size}) -> (%operand as !stream.resource{%size}, !stream.resource{%size}) { - // CHECK-NEXT: stream.cmd.call @asyncExtern(rw 
%[[OPERAND_CAPTURE]][%[[OFFSET]] for %[[LENGTH]]], %c4, wo %[[ALLOC_CAPTURE]][%c0{{[_0-9]*}} for %[[SIZE]]]) : + // CHECK-NEXT: stream.cmd.call @asyncExtern(rw %[[OPERAND_CAPTURE]][%[[OFFSET]] for %[[LENGTH]]], %c4, wo %[[ALLOCA_CAPTURE]][%c0{{[_0-9]*}} for %[[SIZE]]]) : // CHECK-SAME: (!stream.resource{%[[SIZE]]}, index, !stream.resource{%[[SIZE]]}) -> () %0:2 = stream.async.call @asyncExtern(%capture[%offset to %end for %length], %c4) : (!stream.resource{%size}, index) -> (%capture{%size}, !stream.resource{%size}) stream.yield %0#0, %0#1 : !stream.resource{%size}, !stream.resource{%size} @@ -585,7 +592,7 @@ func.func @applyAsyncCallOp(%operand: !stream.resource, %size: index, util.optimization_barrier %result_timepoint : !stream.timepoint // CHECK: util.optimization_barrier %[[OPERAND]] util.optimization_barrier %results#0 : !stream.resource - // CHECK: util.optimization_barrier %[[ALLOC]] + // CHECK: util.optimization_barrier %[[ALLOCA]] util.optimization_barrier %results#1 : !stream.resource return } diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp index 5db024976658b..fda46c4648720 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp @@ -70,16 +70,11 @@ struct ResourceAllocOpPattern Value minAlignment = rewriter.create(allocOp.getLoc(), 64); - SmallVector results; - for (auto [resourceResult, storageSize] : - llvm::zip_equal(allocOp.getResults(), allocOp.getStorageSizes())) { - auto allocateOp = rewriter.create( - allocOp.getLoc(), deviceBufferType, hostBufferType, minAlignment, - storageSize); - results.push_back(allocateOp.getResult()); - } + auto allocateOp = rewriter.create( + allocOp.getLoc(), deviceBufferType, hostBufferType, minAlignment, + adaptor.getStorageSize()); + 
rewriter.replaceOp(allocOp, allocateOp.getResult()); - rewriter.replaceOp(allocOp, results); return success(); } }; diff --git a/experimental/cuda2/cuda_buffer.c b/experimental/cuda2/cuda_buffer.c index ff5a254e5d253..d1d017fed2024 100644 --- a/experimental/cuda2/cuda_buffer.c +++ b/experimental/cuda2/cuda_buffer.c @@ -43,6 +43,13 @@ iree_status_t iree_hal_cuda2_buffer_wrap( void* host_ptr, iree_hal_buffer_release_callback_t release_callback, iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer) { IREE_ASSERT_ARGUMENT(out_buffer); + if (!host_ptr && iree_any_bit_set(allowed_usage, + IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT | + IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "mappable buffers require host pointers"); + } + IREE_TRACE_ZONE_BEGIN(z0); iree_hal_cuda2_buffer_t* buffer = NULL; @@ -95,6 +102,7 @@ static iree_status_t iree_hal_cuda2_buffer_map_range( ? IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT : IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)); + IREE_ASSERT(buffer->host_ptr, "mappable buffers require host pointers"); uint8_t* data_ptr = (uint8_t*)(buffer->host_ptr) + local_byte_offset; // If we mapped for discard scribble over the bytes. This is not a mandated // behavior but it will make debugging issues easier. 
Alternatively for diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_buffer.c b/runtime/src/iree/hal/drivers/cuda/cuda_buffer.c index bcb1ad742536d..f9f33b4f10688 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_buffer.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_buffer.c @@ -43,6 +43,13 @@ iree_status_t iree_hal_cuda_buffer_wrap( void* host_ptr, iree_hal_buffer_release_callback_t release_callback, iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer) { IREE_ASSERT_ARGUMENT(out_buffer); + if (!host_ptr && iree_any_bit_set(allowed_usage, + IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT | + IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "mappable buffers require host pointers"); + } + IREE_TRACE_ZONE_BEGIN(z0); iree_hal_cuda_buffer_t* buffer = NULL; @@ -93,6 +100,7 @@ static iree_status_t iree_hal_cuda_buffer_map_range( ? IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT : IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)); + IREE_ASSERT(buffer->host_ptr, "mappable buffers require host pointers"); uint8_t* data_ptr = (uint8_t*)(buffer->host_ptr) + local_byte_offset; // If we mapped for discard scribble over the bytes. This is not a mandated // behavior but it will make debugging issues easier. Alternatively for From 417f32380506689989dd4c35a019f4a9aae4ad75 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 12 Jun 2023 10:47:45 -0700 Subject: [PATCH 41/44] Switching external resources to be device-local only. Previously all external resources (results returned by an invocation) were made host-visible and mappable and this prevented the use of queue-ordered allocations in CUDA as memory pools cannot service memory with associated host pointers. Depending on device the host-visible memory could also be much slower to access (or have more potential pitfalls with page management) vs pinned device-local memory and this got worse once we started doing more dispatches in-place on the results. 
Now all external buffers are by default allocated as device-local. Users will need to manually stage the buffers and otherwise they'll remain on-device. For externalized state this is a good thing as it means we'll keep state on device automatically. A temporary flag has been added to revert to the old mappable behavior with `--iree-stream-external-resources-mappable=true`. Note that some devices (like CPU) will always allow mapping even if not requested and users can avoid the copies by checking before performing the transfers. --- .../HAL/Conversion/ConversionTarget.cpp | 36 ---- .../Dialect/HAL/Conversion/ConversionTarget.h | 42 ---- .../HAL/Conversion/StreamToHAL/Patterns.cpp | 35 ++-- .../HAL/Transforms/test/convert_to_hal.mlir | 4 +- .../Dialect/Stream/Analysis/ResourceUsage.cpp | 55 +++++- .../Conversion/FlowToStream/Patterns.cpp | 8 +- .../Modules/Check/Conversion/BUILD.bazel | 1 + .../Modules/Check/Conversion/CMakeLists.txt | 1 + .../Check/Conversion/ConversionPatterns.cpp | 89 ++++++++- .../compiler/Modules/Check/IR/BUILD.bazel | 1 + .../compiler/Modules/Check/IR/CMakeLists.txt | 1 + .../Modules/Check/IR/CheckDialect.cpp | 3 + .../compiler/Modules/Check/IR/CheckOps.cpp | 2 +- .../compiler/Modules/Check/IR/CheckOps.td | 51 +++-- .../compiler/Modules/Check/check.imports.mlir | 3 + experimental/cuda2/cuda_device.c | 2 +- .../src/iree/hal/drivers/cuda/cuda_device.c | 2 +- runtime/src/iree/modules/check/check_test.cc | 3 + runtime/src/iree/modules/check/module.cc | 106 +++++++++- .../src/iree/modules/check/test/success.mlir | 3 +- runtime/src/iree/modules/hal/types.c | 6 +- runtime/src/iree/tooling/run_module.c | 16 ++ runtime/src/iree/tooling/vm_util.c | 181 ++++++++++++++++++ runtime/src/iree/tooling/vm_util.h | 10 + tools/BUILD.bazel | 1 + tools/CMakeLists.txt | 1 + tools/iree-e2e-matmul-test.c | 51 ++--- tools/iree-run-trace-main.c | 15 ++ 28 files changed, 571 insertions(+), 158 deletions(-) diff --git 
a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp index 02785f5855af6..c9fd8df92ac3a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp @@ -9,8 +9,6 @@ #include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -43,39 +41,5 @@ HALConversionTarget::HALConversionTarget(MLIRContext *context, }); } -// static -LogicalResult HALConversionTarget::applyDefaultBufferRewrite( - Operation *srcOp, ValueRange operands, StringRef dstOpName, - TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { - OperationState state{srcOp->getLoc(), dstOpName}; - state.addAttributes(srcOp->getAttrs()); - - for (auto [srcOperand, dstOperand] : - llvm::zip_equal(srcOp->getOperands(), operands)) { - // Check that any type that should have been mapped to buffer view was. - // This is just to catch conflicts in type conversions that may sneak in - // during development. - assert( - (!HALTypeConverter::shouldConvertToBufferView(srcOperand.getType()) || - dstOperand.getType().isa()) && - "expect that tensors have been mapped to buffer views"); - state.addOperands({dstOperand}); - } - for (auto resultType : srcOp->getResultTypes()) { - if (HALTypeConverter::shouldConvertToBufferView(resultType)) { - state.addTypes(IREE::HAL::BufferViewType::get(rewriter.getContext())); - } else { - // Normal pass-through result. 
- if (failed(typeConverter.convertType(resultType, state.types))) { - return failure(); - } - } - } - - auto *dstOp = rewriter.create(state); - rewriter.replaceOp(srcOp, dstOp->getResults()); - return success(); -} - } // namespace iree_compiler } // namespace mlir diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h index b41dd1f1a5e67..fd3d4898682db 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h @@ -8,7 +8,6 @@ #define IREE_COMPILER_DIALECT_HAL_CONVERSION_CONVERSIONTARGET_H_ #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -22,47 +21,6 @@ namespace iree_compiler { class HALConversionTarget : public ConversionTarget { public: HALConversionTarget(MLIRContext *context, TypeConverter &typeConverter); - - // Attempts to rewrite an op that may use tensor values into an op using HAL - // buffers. See HALOpConversion for more information. - static LogicalResult - applyDefaultBufferRewrite(Operation *srcOp, ValueRange operands, - StringRef dstOpName, TypeConverter &typeConverter, - ConversionPatternRewriter &rewriter); -}; - -// HAL tensor-to-buffer conversion utility. -// This can be used by dialects to model custom op conversion from a dialect -// that uses the MLIR tensor type to the IREE HAL buffer type. At this point -// during conversion the source values will be TensorType and the target values -// will be IREE::HAL::BufferTypes. Any static information available about the -// tensor (such as static dimensions, element type, layout, etc) are extracted -// here and lowered as expanded values. -// -// The ABI is currently very basic and will change with the introduction of more -// dynamic shape logic. 
-// -// Source: -// my.tensor_op(%arg0 : tensor<2x4xf32>) -// Target: -// %arg0_view = hal.buffer_view.create %arg0, ... -// my.buffer_op(%arg0_view : !hal.buffer_view) -template -class HALOpConversion : public OpConversionPattern { -public: - HALOpConversion(MLIRContext *context, TypeConverter &typeConverter) - : OpConversionPattern(context), typeConverter(typeConverter) {} - - LogicalResult - matchAndRewrite(SRC srcOp, typename SRC::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - return HALConversionTarget::applyDefaultBufferRewrite( - srcOp, adaptor.getOperands(), DST::getOperationName(), typeConverter, - rewriter); - } - -protected: - TypeConverter &typeConverter; }; } // namespace iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 8d31a02822972..16e1b468c7ebb 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -14,6 +14,7 @@ #include "iree/compiler/Dialect/Stream/IR/StreamOps.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -23,6 +24,14 @@ namespace mlir { namespace iree_compiler { +static llvm::cl::opt clExternalResourcesMappable( + "iree-stream-external-resources-mappable", + llvm::cl::desc("Allocates external resources as host-visible and mappable. 
" + "This can degrade performance and introduce allocation " + "overhead and staging buffers for readback on the host " + "should be managed by the calling application instead."), + llvm::cl::init(false)); + namespace { static Value lookupDeviceFor(Operation *op, OpBuilder &builder) { @@ -263,17 +272,21 @@ deriveAllowedResourceBufferBits(Location loc, default: break; case IREE::Stream::Lifetime::External: - // #yolo; these come from/go to outside the program. - // Today we assume they are device-local|host-visible just for - // practical purposes but that does not have to be true. We really - // want this to be something we analyze and handle on the edges - // (transferring devices/etc if needed). - memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal | - IREE::HAL::MemoryTypeBitfield::HostVisible; - // NOTE: we may not map it but users may after they get them back. - // Another reason we should annotate this - having a buffer be - // mappable is potentially expensive (may get a 2nd copy in memory!). - bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping; + if (clExternalResourcesMappable) { + // #yolo; these come from/go to outside the program. + // Today we assume they are device-local|host-visible just for + // practical purposes but that does not have to be true. We really + // want this to be something we analyze and handle on the edges + // (transferring devices/etc if needed). + memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal | + IREE::HAL::MemoryTypeBitfield::HostVisible; + // NOTE: we may not map it but users may after they get them back. + // Another reason we should annotate this - having a buffer be + // mappable is potentially expensive (may get a 2nd copy in memory!). 
+ bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping; + } else { + memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal; + } break; } return success(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir index 2aca6c590a142..d45978d6c2730 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir @@ -80,8 +80,8 @@ module attributes {hal.device.targets = [#device_target_cpu]} { %arg1_resource = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<4xf32> in !stream.resource{%c16} // CHECK: %[[RESULT_BUFFER:.+]] = hal.allocator.allocate<%[[ALLOCATOR]] : !hal.allocator> - // CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal") - // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}Mapping{{.+}}") + // CHECK-SAME: type("DeviceVisible|DeviceLocal") + // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") // CHECK-SAME: : !hal.buffer{%c16} %result_resource = stream.resource.alloc uninitialized : !stream.resource{%c16} diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp index e7e2bedbf4b64..7254cee0bf995 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp @@ -307,7 +307,27 @@ class ValueResourceUsage : public AbstractResourceUsage { getState() ^= targetUsage.getState(); }) .Case([&](IREE::Stream::TensorImportOp op) { - removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL); + auto targetType = + llvm::cast(op.getResult().getType()); + switch (targetType.getLifetime()) { + default: + case IREE::Stream::Lifetime::External: + removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL); + break; + case 
IREE::Stream::Lifetime::Staging: + removeAssumedBits(NOT_MUTATED | NOT_STAGING_READ | + NOT_STAGING_WRITE); + break; + case IREE::Stream::Lifetime::Transient: + removeAssumedBits(NOT_MUTATED); + break; + case IREE::Stream::Lifetime::Variable: + removeAssumedBits(NOT_MUTATED | NOT_GLOBAL_READ | NOT_GLOBAL_WRITE); + break; + case IREE::Stream::Lifetime::Constant: + removeAssumedBits(NOT_CONSTANT); + break; + } auto &resultUsage = solver.getElementFor( *this, Position::forValue(op.getResult()), DFX::Resolution::REQUIRED); @@ -497,7 +517,6 @@ class ValueResourceUsage : public AbstractResourceUsage { *this, Position::forValue(op->getOperand(operandIdx)), DFX::Resolution::REQUIRED); getState() ^= operandUsage.getState(); - auto &beforeUsage = solver.getElementFor( *this, Position::forValue(op.getBeforeBody()->getArgument(operandIdx)), @@ -510,13 +529,11 @@ class ValueResourceUsage : public AbstractResourceUsage { *this, Position::forValue(op->getOperand(operandIdx)), DFX::Resolution::REQUIRED); getState() ^= operandUsage.getState(); - auto &parentUsage = solver.getElementFor( *this, Position::forValue(op->getParentOp()->getResult(operandIdx - 1)), DFX::Resolution::REQUIRED); getState() ^= parentUsage.getState(); - if (auto whileOp = dyn_cast_or_null(op->getParentOp())) { auto value = Position::forValue( @@ -532,14 +549,12 @@ class ValueResourceUsage : public AbstractResourceUsage { *this, Position::forValue(op->getOperand(operandIdx)), DFX::Resolution::REQUIRED); getState() ^= operandUsage.getState(); - auto &parentUsage = solver.getElementFor( *this, Position::forValue(op->getParentOp()->getResult(operandIdx)), DFX::Resolution::REQUIRED); getState() ^= parentUsage.getState(); } - if (auto whileOp = dyn_cast_or_null(op->getParentOp())) { auto value = @@ -589,7 +604,33 @@ class ValueResourceUsage : public AbstractResourceUsage { removeAssumedBits(NOT_INDIRECT | NOT_GLOBAL_WRITE); }) .Case([&](IREE::Stream::TensorExportOp op) { - removeAssumedBits(NOT_MUTATED | 
NOT_EXTERNAL); + auto sourceType = + llvm::cast(op.getSource().getType()); + switch (sourceType.getLifetime()) { + default: + case IREE::Stream::Lifetime::External: + removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL); + break; + case IREE::Stream::Lifetime::Staging: + removeAssumedBits(NOT_MUTATED | NOT_STAGING_READ | + NOT_STAGING_WRITE | NOT_TRANSFER_READ | + NOT_TRANSFER_WRITE); + break; + case IREE::Stream::Lifetime::Transient: + removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_READ | + NOT_TRANSFER_WRITE | NOT_DISPATCH_READ | + NOT_DISPATCH_WRITE); + break; + case IREE::Stream::Lifetime::Variable: + removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_READ | + NOT_TRANSFER_WRITE | NOT_DISPATCH_READ | + NOT_DISPATCH_WRITE); + break; + case IREE::Stream::Lifetime::Constant: + removeAssumedBits(NOT_CONSTANT | NOT_TRANSFER_READ | + NOT_DISPATCH_READ); + break; + } }) .Case([&](IREE::Stream::AsyncCloneOp op) { removeAssumedBits(NOT_TRANSFER_READ); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index 0ed4955461859..b2a6abf26f66f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -249,12 +249,12 @@ struct ConvertTensorTraceOp llvm::zip_equal(op.getOperands(), adaptor.getOperands())) { auto source = consumeTensorOperand(op.getLoc(), resourceOperand, rewriter); - auto externalType = rewriter.getType( - IREE::Stream::Lifetime::External); + auto stagingType = rewriter.getType( + IREE::Stream::Lifetime::Staging); auto exportSource = resourceOperand; - if (source.resource.getType() != externalType) { + if (source.resource.getType() != stagingType) { exportSource = rewriter.create( - op.getLoc(), externalType, source.resource, source.resourceSize, + op.getLoc(), stagingType, source.resource, source.resourceSize, source.resourceSize, 
/*source_affinity=*/getAffinityFor(op), /*result_affinity=*/nullptr); diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel index 0644bdaac80f2..4dcde8c579a98 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel @@ -22,6 +22,7 @@ iree_compiler_cc_library( ], deps = [ "//compiler/src/iree/compiler/Dialect/HAL/Conversion", + "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/VM/Conversion", "//compiler/src/iree/compiler/Modules/Check/IR", "@llvm-project//mlir:Pass", diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt index 582a6ada281e9..c55d7713656e1 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt @@ -21,6 +21,7 @@ iree_cc_library( MLIRPass MLIRTransforms iree::compiler::Dialect::HAL::Conversion + iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::VM::Conversion iree::compiler::Modules::Check::IR PUBLIC diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp index 82da66bd72006..10cdbb35684ac 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp @@ -7,6 +7,8 @@ #include "iree/compiler/Modules/Check/Conversion/ConversionPatterns.h" #include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h" +#include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" #include "iree/compiler/Modules/Check/IR/CheckOps.h" 
#include "mlir/Pass/Pass.h" @@ -60,17 +62,90 @@ void populateCheckToVMPatterns(MLIRContext *context, SymbolTable &importSymbols, context, importSymbols, typeConverter, "check.expect_almost_eq"); } +// Attempts to rewrite an op that may use tensor values into an op using HAL +// buffers. +static LogicalResult applyDefaultCheckBufferRewrite( + Operation *srcOp, ValueRange operands, StringRef dstOpName, + TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { + OperationState state{srcOp->getLoc(), dstOpName}; + state.addAttributes(srcOp->getAttrs()); + + // Add device argument. + Value device = rewriter.create(srcOp->getLoc()); + state.addOperands({device}); + + for (auto [srcOperand, dstOperand] : + llvm::zip_equal(srcOp->getOperands(), operands)) { + // Check that any type that should have been mapped to buffer view was. + // This is just to catch conflicts in type conversions that may sneak in + // during development. + assert( + (!HALTypeConverter::shouldConvertToBufferView(srcOperand.getType()) || + dstOperand.getType().isa()) && + "expect that tensors have been mapped to buffer views"); + state.addOperands({dstOperand}); + } + for (auto resultType : srcOp->getResultTypes()) { + if (HALTypeConverter::shouldConvertToBufferView(resultType)) { + state.addTypes(IREE::HAL::BufferViewType::get(rewriter.getContext())); + } else { + // Normal pass-through result. + if (failed(typeConverter.convertType(resultType, state.types))) { + return failure(); + } + } + } + + auto *dstOp = rewriter.create(state); + rewriter.replaceOp(srcOp, dstOp->getResults()); + return success(); +} + +// HAL tensor-to-buffer conversion utility. +// This can be used by dialects to model custom op conversion from a dialect +// that uses the MLIR tensor type to the IREE HAL buffer type. At this point +// during conversion the source values will be TensorType and the target values +// will be IREE::HAL::BufferTypes. 
Any static information available about the +// tensor (such as static dimensions, element type, layout, etc) are extracted +// here and lowered as expanded values. +// +// The ABI is currently very basic and will change with the introduction of more +// dynamic shape logic. +// +// Source: +// my.tensor_op(%arg0 : tensor<2x4xf32>) +// Target: +// %arg0_view = hal.buffer_view.create %arg0, ... +// my.buffer_op(%arg0_view : !hal.buffer_view) +template +class HALCheckOpConversion : public OpConversionPattern { +public: + HALCheckOpConversion(MLIRContext *context, TypeConverter &typeConverter) + : OpConversionPattern(context), typeConverter(typeConverter) {} + + LogicalResult + matchAndRewrite(SRC srcOp, typename SRC::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return applyDefaultCheckBufferRewrite(srcOp, adaptor.getOperands(), + DST::getOperationName(), + typeConverter, rewriter); + } + +protected: + TypeConverter &typeConverter; +}; + void populateCheckToHALPatterns(MLIRContext *context, RewritePatternSet &patterns, TypeConverter &typeConverter) { // The same op handles both tensors and buffer views. 
- patterns - .insert, - HALOpConversion, - HALOpConversion>(context, - typeConverter); + patterns.insert< + HALCheckOpConversion, + HALCheckOpConversion, + HALCheckOpConversion>(context, + typeConverter); } } // namespace Check diff --git a/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel index dff0294ea0b40..e55f3d25bd894 100644 --- a/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel @@ -57,6 +57,7 @@ iree_compiler_cc_library( ":IR", ":check_ops_gen", "//compiler/src/iree/compiler/Dialect/HAL/Conversion", + "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", "//compiler/src/iree/compiler/Dialect/VM/Conversion", "//compiler/src/iree/compiler/Modules/Check:check_imports", "//compiler/src/iree/compiler/Modules/Check/Conversion", diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt b/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt index c3a85740d27ec..b0928ce62b9fc 100644 --- a/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt @@ -42,6 +42,7 @@ iree_cc_library( MLIRParser MLIRTransforms iree::compiler::Dialect::HAL::Conversion + iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::VM::Conversion iree::compiler::Modules::Check::Conversion iree::compiler::Modules::Check::check_imports diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp b/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp index dbdb4e19f8374..554baa6084e96 100644 --- a/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp +++ b/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Modules/Check/IR/CheckDialect.h" #include "iree/compiler/Dialect/HAL/Conversion/ConversionDialectInterface.h" +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include 
"iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h" #include "iree/compiler/Modules/Check/Conversion/ConversionPatterns.h" #include "iree/compiler/Modules/Check/IR/CheckOps.h" @@ -57,6 +58,8 @@ class CheckToHalConversionInterface : public HALConversionDialectInterface { CheckDialect::CheckDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context, TypeID::get()) { + context->loadDialect(); + addInterfaces(); addInterfaces(); #define GET_OP_LIST diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp index a651bfedcad6e..69cfda7f104b3 100644 --- a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp +++ b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp @@ -24,7 +24,7 @@ struct ExpandAttributeToConst : public OpRewritePattern { LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { auto rhs = rewriter.create(op.getLoc(), op.getValue()); - rewriter.replaceOpWithNewOp(op, op.getLhs(), rhs); + rewriter.replaceOpWithNewOp(op, op.getDevice(), op.getLhs(), rhs); return success(); } }; diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td index 9d0b1b335c6c7..59c2236c2ada3 100644 --- a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td +++ b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td @@ -36,7 +36,6 @@ def CHECK_ExpectTrueOp : Op { let assemblyFormat = "`(` $operand `)` attr-dict `:` type($operand)"; } - def CHECK_ExpectFalseOp : Op { let summary = [{Checks that the operand is false}]; let description = [{ @@ -64,18 +63,24 @@ def CHECK_ExpectAllTrueOp : Op { Issues a non-fatal failure if the verification fails. 
```mlir - check.expect_all_true(%arg0) : !hal.buffer_view + check.expect_all_true<%device>(%arg0) : !hal.buffer_view check.expect_all_true(%arg1) : tensor<2x2xi32> ``` }]; - let arguments = - (ins AnyTypeOf<[HAL_BufferView, TensorOf<[AnySignlessInteger]>]>:$operand); + let arguments = (ins + Optional:$device, + AnyTypeOf<[HAL_BufferView, TensorOf<[AnySignlessInteger]>]>:$operand + ); - let assemblyFormat = "`(` $operand `)` attr-dict `:` type($operand)"; + let assemblyFormat = [{ + (`` `<` $device^ `>`)? + `` `(` $operand `)` attr-dict `:` type($operand) + }]; } -def CHECK_ExpectEqOp : Op { +def CHECK_ExpectEqOp : + Op]> { let summary = [{Checks that the tensor or buffer view operands are equal}]; let description = [{ Verifies that the operands are exactly equal. @@ -88,11 +93,15 @@ def CHECK_ExpectEqOp : Op { }]; let arguments = (ins - AnyTypeOf<[HAL_BufferView, AnyTensor]>:$lhs, - AnyTypeOf<[HAL_BufferView, AnyTensor]>:$rhs + Optional:$device, + AnyTypeOf<[HAL_BufferView, AnyTensor]>:$lhs, + AnyTypeOf<[HAL_BufferView, AnyTensor]>:$rhs ); - let assemblyFormat = "`(` $lhs `,` $rhs `)` attr-dict `:` type($lhs)"; + let assemblyFormat = [{ + (`` `<` $device^ `>`)? + `` `(` $lhs `,` $rhs `)` attr-dict `:` type($lhs) + }]; } def CHECK_ExpectEqConstOp : @@ -111,17 +120,21 @@ def CHECK_ExpectEqConstOp : }]; let arguments = (ins + Optional:$device, AnyTensor:$lhs, ElementsAttr:$value ); let hasCanonicalizer = 1; - let assemblyFormat = "`(` $lhs `,` $value `)` attr-dict `:` type($lhs)"; + let assemblyFormat = [{ + (`` `<` $device^ `>`)? 
+ `` `(` $lhs `,` $value `)` attr-dict `:` type($lhs) + }]; } def CHECK_ExpectAlmostEqOp : - Op { + Op]> { let summary = [{Checks that the operands are almost equal}]; let description = [{ Verifies that the buffer view or tensor operands with float elements are @@ -135,11 +148,15 @@ def CHECK_ExpectAlmostEqOp : }]; let arguments = (ins - AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$lhs, - AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$rhs + Optional:$device, + AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$lhs, + AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$rhs ); - let assemblyFormat = "`(` $lhs `,` $rhs `)` attr-dict `:` type($lhs)"; + let assemblyFormat = [{ + (`` `<` $device^ `>`)? + `` `(` $lhs `,` $rhs `)` attr-dict `:` type($lhs) + }]; } def CHECK_ExpectAlmostEqConstOp : @@ -160,13 +177,17 @@ def CHECK_ExpectAlmostEqConstOp : }]; let arguments = (ins + Optional:$device, TensorOf<[AnyFloat]>:$lhs, ElementsAttr:$value ); let hasCanonicalizer = 1; - let assemblyFormat = "`(` $lhs `,` $value `)` attr-dict `:` type($lhs)"; + let assemblyFormat = [{ + (`` `<` $device^ `>`)? 
+ `` `(` $lhs `,` $value `)` attr-dict `:` type($lhs) + }]; } #endif // IREE_MODULES_CHECK_DIALECT_CHECK_OPS diff --git a/compiler/src/iree/compiler/Modules/Check/check.imports.mlir b/compiler/src/iree/compiler/Modules/Check/check.imports.mlir index 67bae93437b71..63b9d72392b63 100644 --- a/compiler/src/iree/compiler/Modules/Check/check.imports.mlir +++ b/compiler/src/iree/compiler/Modules/Check/check.imports.mlir @@ -15,15 +15,18 @@ vm.import private optional @expect_false( ) vm.import private optional @expect_all_true( + %device : !vm.ref, %operand : !vm.ref, ) vm.import private optional @expect_eq( + %device : !vm.ref, %lhs : !vm.ref, %rhs : !vm.ref ) vm.import private optional @expect_almost_eq( + %device : !vm.ref, %lhs : !vm.ref, %rhs : !vm.ref ) diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c index 370b693ea0cfc..404dd357a7cdc 100644 --- a/experimental/cuda2/cuda_device.c +++ b/experimental/cuda2/cuda_device.c @@ -625,7 +625,7 @@ static iree_status_t iree_hal_cuda2_device_queue_alloca( // allocator is set on the device. iree_status_t status = iree_ok_status(); if (device->supports_memory_pools && - !iree_any_bit_set(params.access, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { + !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { status = iree_hal_cuda2_memory_pools_alloca( &device->memory_pools, device->dispatch_cu_stream, pool, params, allocation_size, out_buffer); diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c index 2c5a91b6caf4e..ea38514abdc7c 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c @@ -570,7 +570,7 @@ static iree_status_t iree_hal_cuda_device_queue_alloca( // allocator is set on the device. 
iree_status_t status = iree_ok_status(); if (device->supports_memory_pools && - !iree_any_bit_set(params.access, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { + !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { status = iree_hal_cuda_memory_pools_alloca(&device->memory_pools, device->stream, pool, params, allocation_size, out_buffer); diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc index 67f194718bc56..7623fb5df5628 100644 --- a/runtime/src/iree/modules/check/check_test.cc +++ b/runtime/src/iree/modules/check/check_test.cc @@ -197,6 +197,9 @@ class CheckTest : public ::testing::Test { IREE_RETURN_IF_ERROR( iree_vm_list_create(iree_vm_make_undefined_type_def(), args.size(), iree_allocator_system(), &inputs_)); + iree_vm_ref_t device_ref = iree_hal_device_retain_ref(device_); + IREE_RETURN_IF_ERROR( + iree_vm_list_push_ref_move(inputs_.get(), &device_ref)); for (auto& arg : args) { iree_vm_ref_t arg_ref = iree_hal_buffer_view_move_ref(arg.get()); IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(inputs_.get(), &arg_ref)); diff --git a/runtime/src/iree/modules/check/module.cc b/runtime/src/iree/modules/check/module.cc index b417eef5ac8c5..edbb9fe994762 100644 --- a/runtime/src/iree/modules/check/module.cc +++ b/runtime/src/iree/modules/check/module.cc @@ -155,6 +155,100 @@ Status ExpectAllTrue(iree_byte_span_t bytes, "unsupported element type %s", element_type_str); } +static StatusOr>> +TransferBuffersToHost( + iree_hal_device_t* device, + const iree::span> source_views) { + IREE_TRACE_SCOPE(); + + // If all buffers are already host-accessible we can skip the transfer. 
+ std::vector> target_views; + bool requires_transfer = false; + for (auto& source_view : source_views) { + iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(source_view.get()); + if (!iree_all_bits_set(iree_hal_buffer_memory_type(buffer), + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) || + !iree_all_bits_set(iree_hal_buffer_allowed_usage(buffer), + IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)) { + requires_transfer = true; + } + } + if (!requires_transfer) { + for (auto& source_view : source_views) target_views.push_back(source_view); + return std::move(target_views); + } + + vm::ref command_buffer; + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create( + device, + IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT | + IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION, + IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_HAL_QUEUE_AFFINITY_ANY, 0, + &command_buffer)); + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_begin(command_buffer.get())); + + iree_hal_buffer_params_t target_params = { + /*.usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + /*.access=*/IREE_HAL_MEMORY_ACCESS_ALL, + /*.type=*/ + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + /*.queue_affinity=*/IREE_HAL_QUEUE_AFFINITY_ANY, + /*.min_alignment=*/0, + }; + for (size_t i = 0; i < source_views.size(); ++i) { + iree_hal_buffer_t* source_buffer = + iree_hal_buffer_view_buffer(source_views[i].get()); + iree_device_size_t buffer_length = + iree_hal_buffer_byte_length(source_buffer); + vm::ref target_buffer; + IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer( + iree_hal_device_allocator(device), target_params, buffer_length, + &target_buffer)); + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_copy_buffer( + command_buffer.get(), source_buffer, 0, target_buffer.get(), 0, + buffer_length)); + vm::ref target_view; + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create_like( + target_buffer.get(), source_views[i].get(), + iree_hal_device_host_allocator(device), &target_view)); + 
target_views.push_back(std::move(target_view)); + } + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_end(command_buffer.get())); + vm::ref semaphore; + IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(device, 0ull, &semaphore)); + vm::ref fence; + IREE_RETURN_IF_ERROR(iree_hal_fence_create_at( + semaphore.get(), 1ull, iree_hal_device_host_allocator(device), &fence)); + IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute( + device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(), + iree_hal_fence_semaphore_list(fence.get()), 1, &command_buffer)); + IREE_RETURN_IF_ERROR( + iree_hal_fence_wait(fence.get(), iree_infinite_timeout())); + return std::move(target_views); +} + +static Status TransferToHost(iree_hal_device_t* device, + vm::ref& buffer_view) { + IREE_TRACE_SCOPE(); + IREE_ASSIGN_OR_RETURN(auto target_views, + TransferBuffersToHost(device, {buffer_view})); + buffer_view = std::move(target_views[0]); + return OkStatus(); +} + +static Status TransferToHost(iree_hal_device_t* device, + vm::ref& buffer_view_a, + vm::ref& buffer_view_b) { + IREE_TRACE_SCOPE(); + IREE_ASSIGN_OR_RETURN( + auto target_views, + TransferBuffersToHost(device, {buffer_view_a, buffer_view_b})); + buffer_view_a = std::move(target_views[0]); + buffer_view_b = std::move(target_views[1]); + return OkStatus(); +} + // Per-context module state. // This can contain "globals" and other arbitrary state. 
// @@ -177,7 +271,9 @@ class CheckModuleState final { return OkStatus(); } - Status ExpectAllTrue(vm::ref operand) { + Status ExpectAllTrue(vm::ref device, + vm::ref operand) { + IREE_RETURN_IF_ERROR(TransferToHost(device.get(), operand)); auto* view = operand.get(); iree_hal_element_type_t element_type = iree_hal_buffer_view_element_type(view); @@ -193,8 +289,10 @@ class CheckModuleState final { return OkStatus(); } - Status ExpectEq(vm::ref lhs_ref, + Status ExpectEq(vm::ref device, + vm::ref lhs_ref, vm::ref rhs_ref) { + IREE_RETURN_IF_ERROR(TransferToHost(device.get(), lhs_ref, rhs_ref)); auto* lhs = lhs_ref.get(); auto* rhs = rhs_ref.get(); @@ -272,8 +370,10 @@ class CheckModuleState final { return OkStatus(); } - Status ExpectAlmostEq(vm::ref lhs_ref, + Status ExpectAlmostEq(vm::ref device, + vm::ref lhs_ref, vm::ref rhs_ref) { + IREE_RETURN_IF_ERROR(TransferToHost(device.get(), lhs_ref, rhs_ref)); auto* lhs = lhs_ref.get(); auto* rhs = rhs_ref.get(); diff --git a/runtime/src/iree/modules/check/test/success.mlir b/runtime/src/iree/modules/check/test/success.mlir index ff5aa8e8d599c..40d8bc33e7338 100644 --- a/runtime/src/iree/modules/check/test/success.mlir +++ b/runtime/src/iree/modules/check/test/success.mlir @@ -14,9 +14,10 @@ func.func @expect_false() { } func.func @expect_all_true() { + %device = hal.ex.shared_device : !hal.device %all_true = util.unfoldable_constant dense<1> : tensor<2x2xi32> %all_true_view = hal.tensor.export %all_true : tensor<2x2xi32> -> !hal.buffer_view - check.expect_all_true(%all_true_view) : !hal.buffer_view + check.expect_all_true<%device>(%all_true_view) : !hal.buffer_view return } diff --git a/runtime/src/iree/modules/hal/types.c b/runtime/src/iree/modules/hal/types.c index 0c7e0d7900f93..52ce5a281a523 100644 --- a/runtime/src/iree/modules/hal/types.c +++ b/runtime/src/iree/modules/hal/types.c @@ -205,7 +205,7 @@ IREE_API_EXPORT iree_hal_buffer_t* iree_vm_list_get_buffer_retain( IREE_API_EXPORT iree_status_t 
iree_vm_list_set_buffer_retain( iree_vm_list_t* list, iree_host_size_t i, iree_hal_buffer_t* value) { - iree_vm_ref_t value_ref; + iree_vm_ref_t value_ref = iree_vm_ref_null(); IREE_RETURN_IF_ERROR( iree_vm_ref_wrap_assign(value, iree_hal_buffer_type(), &value_ref)); return iree_vm_list_set_ref_retain(list, i, &value_ref); @@ -226,7 +226,7 @@ IREE_API_EXPORT iree_hal_buffer_view_t* iree_vm_list_get_buffer_view_retain( IREE_API_EXPORT iree_status_t iree_vm_list_set_buffer_view_retain( iree_vm_list_t* list, iree_host_size_t i, iree_hal_buffer_view_t* value) { - iree_vm_ref_t value_ref; + iree_vm_ref_t value_ref = iree_vm_ref_null(); IREE_RETURN_IF_ERROR( iree_vm_ref_wrap_assign(value, iree_hal_buffer_view_type(), &value_ref)); return iree_vm_list_set_ref_retain(list, i, &value_ref); @@ -247,7 +247,7 @@ IREE_API_EXPORT iree_hal_fence_t* iree_vm_list_get_fence_retain( IREE_API_EXPORT iree_status_t iree_vm_list_set_fence_retain( iree_vm_list_t* list, iree_host_size_t i, iree_hal_fence_t* value) { - iree_vm_ref_t value_ref; + iree_vm_ref_t value_ref = iree_vm_ref_null(); IREE_RETURN_IF_ERROR( iree_vm_ref_wrap_assign(value, iree_hal_fence_type(), &value_ref)); return iree_vm_list_set_ref_retain(list, i, &value_ref); diff --git a/runtime/src/iree/tooling/run_module.c b/runtime/src/iree/tooling/run_module.c index ad5e674fedb1f..2af3db4f891b6 100644 --- a/runtime/src/iree/tooling/run_module.c +++ b/runtime/src/iree/tooling/run_module.c @@ -246,6 +246,22 @@ static iree_status_t iree_tooling_run_function( "processing instrument data"); } + // Transfer outputs to the host so they can be processed. Only required when + // using full HAL device-based execution. 
+ if (iree_status_is_ok(status) && device != NULL) { + iree_hal_buffer_params_t target_params = { + .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + .access = IREE_HAL_MEMORY_ACCESS_ALL, + .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY, + .min_alignment = 0, + }; + status = iree_tooling_transfer_variant_list( + device, outputs, device_allocator, target_params, + /*wait_fence=*/NULL, /*signal_fence=*/NULL); + } + // Handle either printing/writing the outputs or checking them against // expected values (basic pass/fail testing). if (iree_status_is_ok(status)) { diff --git a/runtime/src/iree/tooling/vm_util.c b/runtime/src/iree/tooling/vm_util.c index b21eadaee78aa..70e2e777bf970 100644 --- a/runtime/src/iree/tooling/vm_util.c +++ b/runtime/src/iree/tooling/vm_util.c @@ -324,6 +324,187 @@ iree_status_t iree_tooling_append_async_fence_inputs( return status; } +static bool iree_tooling_requires_buffer_transfer( + iree_hal_buffer_t* source_buffer, iree_hal_buffer_params_t target_params) { + return !iree_all_bits_set(iree_hal_buffer_memory_type(source_buffer), + target_params.type) || + !iree_all_bits_set(iree_hal_buffer_allowed_usage(source_buffer), + target_params.usage); +} + +static iree_status_t iree_tooling_setup_buffer_transfer( + iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer, + iree_hal_allocator_t* target_allocator, + iree_hal_buffer_params_t target_params, + iree_hal_buffer_t** out_target_buffer) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(source_buffer); + IREE_ASSERT_ARGUMENT(target_allocator); + IREE_ASSERT_ARGUMENT(out_target_buffer); + *out_target_buffer = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_buffer_t* target_buffer = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_allocator_allocate_buffer( + target_allocator, target_params, + iree_hal_buffer_allocation_size(source_buffer), 
&target_buffer)); + + iree_status_t status = iree_hal_command_buffer_copy_buffer( + command_buffer, source_buffer, 0, target_buffer, 0, + iree_hal_buffer_byte_length(source_buffer)); + + if (iree_status_is_ok(status)) { + *out_target_buffer = target_buffer; + } else { + iree_hal_buffer_release(target_buffer); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_tooling_submit_transfer( + iree_hal_device_t* device, iree_hal_fence_t* wait_fence, + iree_hal_queue_affinity_t queue_affinity, + iree_hal_command_buffer_t* command_buffer, iree_hal_fence_t* signal_fence) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + + bool needs_wait = signal_fence == NULL; + if (needs_wait) { + iree_hal_semaphore_t* semaphore = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_semaphore_create(device, 0ull, &semaphore)); + status = iree_hal_fence_create_at( + semaphore, 1ull, iree_hal_device_host_allocator(device), &signal_fence); + iree_hal_semaphore_release(semaphore); + } else { + iree_hal_fence_retain(signal_fence); + } + + if (iree_status_is_ok(status)) { + status = iree_hal_device_queue_execute( + device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence), + iree_hal_fence_semaphore_list(signal_fence), 1, &command_buffer); + } + + if (iree_status_is_ok(status) && needs_wait) { + status = iree_hal_fence_wait(signal_fence, iree_infinite_timeout()); + } + + iree_hal_fence_release(signal_fence); + IREE_TRACE_ZONE_END(z0); + return status; +} + +iree_status_t iree_tooling_transfer_variant_list( + iree_hal_device_t* device, iree_vm_list_t* list, + iree_hal_allocator_t* target_allocator, + iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence, + iree_hal_fence_t* signal_fence) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(list); + IREE_ASSERT_ARGUMENT(target_allocator); + IREE_TRACE_ZONE_BEGIN(z0); + + // If all buffers are already host-accessible we can skip the transfer. 
+ bool requires_transfer = false; + for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { + iree_vm_ref_t value = iree_vm_ref_null(); + IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value)); + if (iree_hal_buffer_isa(value)) { + iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value); + if (iree_tooling_requires_buffer_transfer(source_buffer, target_params)) { + requires_transfer = true; + break; + } + } else if (iree_hal_buffer_view_isa(value)) { + iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value); + iree_hal_buffer_t* source_buffer = + iree_hal_buffer_view_buffer(source_view); + if (iree_tooling_requires_buffer_transfer(source_buffer, target_params)) { + requires_transfer = true; + break; + } + } + } + if (!requires_transfer) { + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); + } + + iree_hal_command_buffer_t* command_buffer = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_command_buffer_create( + device, + IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT | + IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION, + IREE_HAL_COMMAND_CATEGORY_TRANSFER, target_params.queue_affinity, + /*binding_capacity=*/0, &command_buffer)); + + iree_status_t status = iree_hal_command_buffer_begin(command_buffer); + if (iree_status_is_ok(status)) { + for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { + iree_vm_ref_t value = iree_vm_ref_null(); + IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value)); + if (iree_hal_buffer_isa(value)) { + iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value); + if (!iree_tooling_requires_buffer_transfer(source_buffer, + target_params)) { + // Already ok. 
+ continue; + } + iree_hal_buffer_t* target_buffer = NULL; + status = iree_tooling_setup_buffer_transfer( + command_buffer, source_buffer, target_allocator, target_params, + &target_buffer); + if (!iree_status_is_ok(status)) break; + status = iree_vm_list_set_buffer_retain(list, i, target_buffer); + iree_hal_buffer_release(target_buffer); + if (!iree_status_is_ok(status)) break; + } else if (iree_hal_buffer_view_isa(value)) { + iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value); + iree_hal_buffer_t* source_buffer = + iree_hal_buffer_view_buffer(source_view); + if (!iree_tooling_requires_buffer_transfer(source_buffer, + target_params)) { + // Already ok. + continue; + } + iree_hal_buffer_t* target_buffer = NULL; + status = iree_tooling_setup_buffer_transfer( + command_buffer, source_buffer, target_allocator, target_params, + &target_buffer); + if (!iree_status_is_ok(status)) break; + iree_hal_buffer_view_t* target_view = NULL; + status = iree_hal_buffer_view_create_like( + target_buffer, source_view, + iree_hal_allocator_host_allocator(target_allocator), &target_view); + iree_hal_buffer_release(target_buffer); + if (!iree_status_is_ok(status)) break; + status = iree_vm_list_set_buffer_view_retain(list, i, target_view); + iree_hal_buffer_view_release(target_view); + if (!iree_status_is_ok(status)) break; + } + } + } + if (iree_status_is_ok(status)) { + status = iree_hal_command_buffer_end(command_buffer); + } + + if (iree_status_is_ok(status)) { + status = iree_tooling_submit_transfer(device, wait_fence, + target_params.queue_affinity, + command_buffer, signal_fence); + } + + iree_hal_command_buffer_release(command_buffer); + + IREE_TRACE_ZONE_END(z0); + return status; +} + #define IREE_PRINTVARIANT_CASE_I(SIZE, B, V) \ case IREE_VM_VALUE_TYPE_I##SIZE: \ return iree_string_builder_append_format( \ diff --git a/runtime/src/iree/tooling/vm_util.h b/runtime/src/iree/tooling/vm_util.h index e2a0311923780..bc9ca008236b9 100644 --- 
a/runtime/src/iree/tooling/vm_util.h +++ b/runtime/src/iree/tooling/vm_util.h @@ -54,6 +54,16 @@ iree_status_t iree_tooling_append_async_fence_inputs( iree_hal_device_t* device, iree_hal_fence_t* wait_fence, iree_hal_fence_t** out_signal_fence); +// Transfers all buffers in |list| to ones using |target_params|. +// If no |wait_fence| is provided then the transfer will begin immediately. +// If no |signal_fence| is provided then the call will block until the transfer +// completes. +iree_status_t iree_tooling_transfer_variant_list( + iree_hal_device_t* device, iree_vm_list_t* list, + iree_hal_allocator_t* target_allocator, + iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence, + iree_hal_fence_t* signal_fence); + // Appends a variant list of VM scalars and buffers to |builder|. // |list_name| will be printed alongside each element ordinal. // diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel index 34ace600ce1a4..0f376f0507e2e 100644 --- a/tools/BUILD.bazel +++ b/tools/BUILD.bazel @@ -210,6 +210,7 @@ iree_runtime_cc_binary( "//runtime/src/iree/modules/hal", "//runtime/src/iree/tooling:device_util", "//runtime/src/iree/tooling:trace_replay", + "//runtime/src/iree/tooling:vm_util", "//runtime/src/iree/tooling:yaml_util", "//runtime/src/iree/vm", "@com_github_yaml_libyaml//:yaml", diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 8403cfcd3bb46..28878f4fb0ea2 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -215,6 +215,7 @@ iree_cc_binary( iree::modules::hal iree::tooling::device_util iree::tooling::trace_replay + iree::tooling::vm_util iree::tooling::yaml_util iree::vm yaml diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c index d1082a216bb08..5730608682d27 100644 --- a/tools/iree-e2e-matmul-test.c +++ b/tools/iree-e2e-matmul-test.c @@ -19,6 +19,7 @@ #include "iree/modules/hal/module.h" #include "iree/tooling/device_util.h" #include "iree/tooling/trace_replay.h" +#include "iree/tooling/vm_util.h" 
#include "iree/tooling/yaml_util.h" #include "iree/vm/api.h" @@ -192,10 +193,8 @@ static iree_status_t map_host_local_row_major_data( iree_hal_buffer_view_t* buffer_view, enum iree_hal_memory_access_bits_t access, iree_hal_buffer_mapping_t* mapping) { - // Really validate host-local, not just host-visible: callers may rely on - // host-coherency. IREE_RETURN_IF_ERROR( - validate_memory_type(buffer_view, IREE_HAL_MEMORY_TYPE_HOST_LOCAL)); + validate_memory_type(buffer_view, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); if (iree_hal_buffer_view_encoding_type(buffer_view) != IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR) { return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, @@ -1014,42 +1013,46 @@ static iree_status_t do_matmul_and_check_results( replay->device, device_allocator, device_inputs, &host_inputs)); // Invoke the function to produce the actual result. - iree_vm_list_t* device_outputs = NULL; + iree_vm_list_t* outputs = NULL; IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), /*initial_capacity=*/8, - replay->host_allocator, &device_outputs)); + replay->host_allocator, &outputs)); IREE_CHECK_OK(iree_vm_invoke( replay->context, function, IREE_VM_INVOCATION_FLAG_NONE, - /*policy=*/NULL, device_inputs, device_outputs, replay->host_allocator)); + /*policy=*/NULL, device_inputs, outputs, replay->host_allocator)); iree_vm_list_release(device_inputs); - // Get the device_actual_result from the device_outputs. - iree_hal_buffer_view_t* device_actual_result; - IREE_CHECK_OK( - get_item_as_buffer_view(device_outputs, 0, &device_actual_result)); + // Transfer device buffers to host buffers. 
+ iree_hal_buffer_params_t host_params = { + .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + .access = IREE_HAL_MEMORY_ACCESS_ALL, + .type = + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY, + .min_alignment = 0, + }; + IREE_CHECK_OK(iree_tooling_transfer_variant_list( + replay->device, outputs, device_allocator, host_params, + /*wait_fence=*/NULL, /*signal_fence=*/NULL)); - // Copy the results to a host local buffer to be able to map it. - iree_hal_buffer_view_t* host_actual_result = NULL; - IREE_CHECK_OK(copy_device_buffer_view_to_host( - replay->device, device_allocator, device_actual_result, - &host_actual_result)); + // Get the actual result computed by the program. + iree_hal_buffer_view_t* actual_result; + IREE_CHECK_OK(get_item_as_buffer_view(outputs, 0, &actual_result)); - // Allocate host_expected_result with same shape as host_actual_result. + // Allocate host_expected_result with same shape as actual_result. iree_hal_buffer_view_t* host_expected_result = NULL; - IREE_CHECK_OK(allocate_host_buffer_view_like(replay->device, device_allocator, - host_actual_result, - &host_expected_result)); + IREE_CHECK_OK(allocate_host_buffer_view_like( + replay->device, device_allocator, actual_result, &host_expected_result)); // Use the reference matmul implementation to fill host_expected_result IREE_CHECK_OK(reference_matmul(host_inputs, host_expected_result)); - // Check that host_actual_result and host_expected_result agree. - iree_status_t status = check_matmul_results( - file, host_inputs, host_actual_result, host_expected_result); + // Check that actual_result and host_expected_result agree. 
+ iree_status_t status = check_matmul_results(file, host_inputs, actual_result, + host_expected_result); - iree_vm_list_release(device_outputs); // releases device_actual_result + iree_vm_list_release(outputs); // releases actual_result iree_vm_list_release(host_inputs); - iree_hal_buffer_view_release(host_actual_result); iree_hal_buffer_view_release(host_expected_result); return status; } diff --git a/tools/iree-run-trace-main.c b/tools/iree-run-trace-main.c index fa46810c28df2..b1b39dc6caaa8 100644 --- a/tools/iree-run-trace-main.c +++ b/tools/iree-run-trace-main.c @@ -197,6 +197,21 @@ static iree_status_t iree_run_trace_file(iree_string_view_t root_path, yaml_parser_delete(&parser); + // Transfer outputs to the host so they can be processed. + if (iree_status_is_ok(status) && replay.device != NULL) { + iree_hal_buffer_params_t target_params = { + .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + .access = IREE_HAL_MEMORY_ACCESS_ALL, + .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY, + .min_alignment = 0, + }; + status = iree_tooling_transfer_variant_list( + replay.device, replay.outputs, iree_hal_device_allocator(replay.device), + target_params, /*wait_fence=*/NULL, /*signal_fence=*/NULL); + } + // Optionally process outputs from the replay session. if (iree_status_is_ok(status)) { if (FLAG_output_list().count == 0) { From d956cfae2fa74e0a23ba99a8c4ca87caa6618e32 Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Mon, 18 Sep 2023 20:58:31 -0700 Subject: [PATCH 42/44] [Bindings] Make copies to local host when map is unavailable. 
--- runtime/bindings/python/hal.cc | 83 +++++++++++++++++++ runtime/bindings/python/hal.h | 11 +++ .../python/iree/runtime/array_interop.py | 49 ++++++++++- runtime/bindings/python/tests/hal_test.py | 46 ++++++++++ .../iree/hal/drivers/cuda/cuda_allocator.c | 6 +- 5 files changed, 191 insertions(+), 4 deletions(-) diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc index f809527377024..d32eee8a203fa 100644 --- a/runtime/bindings/python/hal.cc +++ b/runtime/bindings/python/hal.cc @@ -57,6 +57,18 @@ static const char kHalDeviceQueueExecute[] = signal_semaphores: Semaphores/Fence to signal. )"; +static const char kHalDeviceQueueCopy[] = + R"(Copy data from a source buffer to destination buffer. + +Args: + source_buffer: `HalBuffer` that holds src data. + target_buffer: `HalBuffer` that will receive data. + wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or + a HalFence. The allocation will be made once these semaphores are + satisfied. + signal_semaphores: Semaphores/Fence to signal. +)"; + static const char kHalFenceWait[] = R"(Waits until the fence is signalled or errored. @@ -519,6 +531,69 @@ void HalDevice::QueueExecute(py::handle command_buffers, "executing command buffers"); } +void HalDevice::QueueCopy(HalBuffer& source_buffer, HalBuffer& target_buffer, + py::handle wait_semaphores, + py::handle signal_semaphores) { + iree_hal_semaphore_list_t wait_list; + iree_hal_semaphore_list_t signal_list; + + // Wait list. 
+ if (py::isinstance(wait_semaphores)) { + wait_list = iree_hal_fence_semaphore_list( + py::cast(wait_semaphores)->raw_ptr()); + } else { + size_t wait_count = py::len(wait_semaphores); + wait_list = { + wait_count, + /*semaphores=*/ + static_cast( + alloca(sizeof(iree_hal_semaphore_t*) * wait_count)), + /*payload_values=*/ + static_cast(alloca(sizeof(uint64_t) * wait_count)), + }; + for (size_t i = 0; i < wait_count; ++i) { + py::tuple pair = wait_semaphores[i]; + wait_list.semaphores[i] = py::cast(pair[0])->raw_ptr(); + wait_list.payload_values[i] = py::cast(pair[1]); + } + } + + // Signal list. + if (py::isinstance(signal_semaphores)) { + signal_list = iree_hal_fence_semaphore_list( + py::cast(signal_semaphores)->raw_ptr()); + } else { + size_t signal_count = py::len(signal_semaphores); + signal_list = { + signal_count, + /*semaphores=*/ + static_cast( + alloca(sizeof(iree_hal_semaphore_t*) * signal_count)), + /*payload_values=*/ + static_cast(alloca(sizeof(uint64_t) * signal_count)), + }; + for (size_t i = 0; i < signal_count; ++i) { + py::tuple pair = signal_semaphores[i]; + signal_list.semaphores[i] = py::cast(pair[0])->raw_ptr(); + signal_list.payload_values[i] = py::cast(pair[1]); + } + } + + // TODO: Accept params for src_offset and target_offset. + iree_device_size_t source_length = + iree_hal_buffer_byte_length(source_buffer.raw_ptr()); + if (source_length != iree_hal_buffer_byte_length(target_buffer.raw_ptr())) { + throw std::invalid_argument( + "Source and target buffer length must match and it does not. 
Please " + "check allocations"); + } + CheckApiStatus(iree_hal_device_queue_copy( + raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list, + signal_list, source_buffer.raw_ptr(), 0, + target_buffer.raw_ptr(), 0, source_length), + "Copying buffer on queue"); +} + //------------------------------------------------------------------------------ // HalDriver //------------------------------------------------------------------------------ @@ -855,6 +930,9 @@ void SetupHalBindings(nanobind::module_ m) { .def("queue_execute", &HalDevice::QueueExecute, py::arg("command_buffers"), py::arg("wait_semaphores"), py::arg("signal_semaphores"), kHalDeviceQueueExecute) + .def("queue_copy", &HalDevice::QueueCopy, py::arg("source_buffer"), + py::arg("target_buffer"), py::arg("wait_semaphores"), + py::arg("signal_semaphores"), kHalDeviceQueueCopy) .def("__repr__", [](HalDevice& self) { auto id_sv = iree_hal_device_id(self.raw_ptr()); return std::string(id_sv.data, id_sv.size); @@ -957,6 +1035,9 @@ void SetupHalBindings(nanobind::module_ m) { py::class_(m, "HalBuffer") .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"), py::arg("byte_length")) + .def("byte_length", &HalBuffer::byte_length) + .def("memory_type", &HalBuffer::memory_type) + .def("allowed_usage", &HalBuffer::allowed_usage) .def("create_view", &HalBuffer::CreateView, py::arg("shape"), py::arg("element_size"), py::keep_alive<0, 1>()) .def("map", HalMappedMemory::CreateFromBuffer, py::keep_alive<0, 1>()) @@ -988,6 +1069,8 @@ void SetupHalBindings(nanobind::module_ m) { py::arg("buffer"), py::arg("shape"), py::arg("element_type")); hal_buffer_view .def("map", HalMappedMemory::CreateFromBufferView, py::keep_alive<0, 1>()) + .def("get_buffer", HalBuffer::CreateFromBufferView, + py::keep_alive<0, 1>()) .def_prop_ro("shape", [](HalBufferView& self) { iree_host_size_t rank = diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h index 1fb386307cda8..aa1bdb7613cf7 100644 --- 
a/runtime/bindings/python/hal.h +++ b/runtime/bindings/python/hal.h @@ -127,6 +127,8 @@ class HalDevice : public ApiRefCounted { py::handle signal_semaphores); void QueueExecute(py::handle command_buffers, py::handle wait_semaphores, py::handle signal_semaphores); + void QueueCopy(HalBuffer& src_buffer, HalBuffer& dst_buffer, + py::handle wait_semaphores, py::handle signal_semaphores); }; class HalDriver : public ApiRefCounted { @@ -175,6 +177,10 @@ class HalBuffer : public ApiRefCounted { return iree_hal_buffer_byte_length(raw_ptr()); } + int memory_type() const { return iree_hal_buffer_memory_type(raw_ptr()); } + + int allowed_usage() const { return iree_hal_buffer_allowed_usage(raw_ptr()); } + void FillZero(iree_device_size_t byte_offset, iree_device_size_t byte_length) { CheckApiStatus( @@ -196,6 +202,11 @@ class HalBuffer : public ApiRefCounted { return HalBufferView::StealFromRawPtr(bv); } + static HalBuffer CreateFromBufferView(HalBufferView& bv) { + return HalBuffer::BorrowFromRawPtr( + iree_hal_buffer_view_buffer(bv.raw_ptr())); + } + py::str Repr(); }; diff --git a/runtime/bindings/python/iree/runtime/array_interop.py b/runtime/bindings/python/iree/runtime/array_interop.py index 096fc9b04dda8..fb67b21c70802 100644 --- a/runtime/bindings/python/iree/runtime/array_interop.py +++ b/runtime/bindings/python/iree/runtime/array_interop.py @@ -17,6 +17,7 @@ HalElementType, MappedMemory, MemoryType, + HalFence, ) __all__ = [ @@ -106,6 +107,20 @@ def to_host(self) -> np.ndarray: self._transfer_to_host(False) return self._host_array + def _is_mappable(self) -> bool: + buffer = self._buffer_view.get_buffer() + if ( + buffer.memory_type() & int(MemoryType.HOST_VISIBLE) + != MemoryType.HOST_VISIBLE + ): + return False + if ( + buffer.allowed_usage() & int(BufferUsage.MAPPING_SCOPED) + != BufferUsage.MAPPING_SCOPED + ): + return False + return True + def _transfer_to_host(self, implicit): if self._host_array is not None: return @@ -114,7 +129,10 @@ def 
_transfer_to_host(self, implicit): "DeviceArray cannot be implicitly transferred to the host: " "if necessary, do an explicit transfer via .to_host()" ) - self._mapped_memory, self._host_array = self._map_to_host() + if self._is_mappable(): + self._mapped_memory, self._host_array = self._map_to_host() + else: + self._host_array = self._copy_to_host() def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]: # TODO: When synchronization is enabled, need to block here. @@ -129,6 +147,35 @@ def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]: host_array = host_array.astype(self._override_dtype) return mapped_memory, host_array + def _copy_to_host(self) -> np.ndarray: + # TODO: When synchronization is enabled, need to block here. + source_buffer = self._buffer_view.get_buffer() + host_buffer = self._device.allocator.allocate_buffer( + memory_type=(MemoryType.HOST_LOCAL | MemoryType.DEVICE_VISIBLE), + allowed_usage=(BufferUsage.TRANSFER_TARGET | BufferUsage.MAPPING_SCOPED), + allocation_size=source_buffer.byte_length(), + ) + # Copy and wait for buffer to be copied from source buffer. + sem = self._device.create_semaphore(0) + self._device.queue_copy( + source_buffer, + host_buffer, + wait_semaphores=HalFence.create_at(sem, 0), + signal_semaphores=HalFence.create_at(sem, 1), + ) + HalFence.create_at(sem, 1).wait() + # Map and reformat buffer as np.array. + raw_dtype = self._get_raw_dtype() + mapped_memory = host_buffer.map() + host_array = mapped_memory.asarray(self._buffer_view.shape, raw_dtype) + # Detect if we need to force an explicit conversion. This happens when + # we were requested to pretend that the array is in a specific dtype, + # even if that is not representable on the device. You guessed it: + # this is to support bools. 
+ if self._override_dtype is not None and self._override_dtype != raw_dtype: + host_array = host_array.astype(self._override_dtype) + return host_array + def _get_raw_dtype(self): return HalElementType.map_to_dtype(self._buffer_view.element_type) diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py index cbf45a3741ede..ea1f03167f717 100644 --- a/runtime/bindings/python/tests/hal_test.py +++ b/runtime/bindings/python/tests/hal_test.py @@ -264,6 +264,52 @@ def testFenceExtend(self): fence.extend(iree.runtime.HalFence.create_at(sem2, 2)) self.assertEqual(fence.timepoint_count, 2) + def testRoundTripQueueCopy(self): + original_ary = np.zeros([3, 4], dtype=np.int32) + 2 + source_bv = self.allocator.allocate_buffer_copy( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, + device=self.device, + buffer=original_ary, + element_type=iree.runtime.HalElementType.SINT_32, + ) + source_buffer = source_bv.get_buffer() + target_buffer = self.allocator.allocate_buffer( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, + allocation_size=source_buffer.byte_length(), + ) + sem = self.device.create_semaphore(0) + self.device.queue_copy( + source_buffer, + target_buffer, + wait_semaphores=iree.runtime.HalFence.create_at(sem, 0), + signal_semaphores=iree.runtime.HalFence.create_at(sem, 1), + ) + iree.runtime.HalFence.create_at(sem, 1).wait() + copy_ary = target_buffer.map().asarray(original_ary.shape, original_ary.dtype) + np.testing.assert_array_equal(original_ary, copy_ary) + + def testDifferentSizeQueueCopy(self): + source_buffer = self.allocator.allocate_buffer( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, + allocation_size=12, + ) + target_buffer = self.allocator.allocate_buffer( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, 
+ allocation_size=13, + ) + sem = self.device.create_semaphore(0) + with self.assertRaisesRegex(ValueError, "length must match"): + self.device.queue_copy( + source_buffer, + target_buffer, + wait_semaphores=iree.runtime.HalFence.create_at(sem, 0), + signal_semaphores=iree.runtime.HalFence.create_at(sem, 1), + ) + def testCommandBufferStartsByDefault(self): cb = iree.runtime.HalCommandBuffer(self.device) with self.assertRaisesRegex(RuntimeError, "FAILED_PRECONDITION"): diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c b/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c index d010be080e650..fcbb2e464f15a 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c @@ -236,12 +236,12 @@ iree_hal_cuda_allocator_query_buffer_compatibility( if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE; } + if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; + } // Buffers can only be used on the queue if they are device visible. if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { - if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { - compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; - } if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) { compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH; From 675b380983b27878e98e528e619d0ccefdd62d74 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 27 Sep 2023 21:55:54 -0700 Subject: [PATCH 43/44] Add pass for lowering to accel ukernels. Add AccelMatmulExpert pass pipeline Apply clang-format to new C++ files. Remove the currently unused skip-intermediate-roundings option. Reorder entries in BUILD.bazel and CMakeLists.txt alphabetically. (#5) Wrap everything but the factory in the unnamed namespace. 
(#8) use 'accel' identifier Use parameter struct calling convention Use mmt4d path instead of defining linalg.matmul path. Fix lit test and pass to use mmt4d op. (WIP) Use rank-reduced slices of operands in ukernel call. Fix rank reduction and lit test. Remove isInitializedToZero from accel codegen Fix dims Fix post-accel lowering passes to handle linalg.fill Fix lit test expected output. Correct dims for plugin interface Co-authored-by: Sungsoon Cho --- .../iree/compiler/Codegen/LLVMCPU/BUILD.bazel | 1 + .../compiler/Codegen/LLVMCPU/CMakeLists.txt | 1 + .../LLVMCPU/LLVMCPULowerToAccelUKernels.cpp | 205 ++++++++++++++++++ .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 23 +- .../iree/compiler/Codegen/LLVMCPU/Passes.h | 4 + .../iree/compiler/Codegen/LLVMCPU/Passes.td | 8 + .../compiler/Codegen/LLVMCPU/test/BUILD.bazel | 1 + .../Codegen/LLVMCPU/test/CMakeLists.txt | 1 + .../test/lower_to_accel_ukernel_ops.mlir | 39 ++++ 9 files changed, 281 insertions(+), 2 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp create mode 100644 compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_accel_ukernel_ops.mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel index e7b0db3569f15..d89b661cf2ffb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel @@ -57,6 +57,7 @@ iree_compiler_cc_library( "LLVMCPUEmitVectorizationRemarks.cpp", "LLVMCPULinkExecutables.cpp", "LLVMCPULowerExecutableTarget.cpp", + "LLVMCPULowerToAccelUKernels.cpp", "LLVMCPULowerToUKernels.cpp", "LLVMCPUMmt4dVectorLowering.cpp", "LLVMCPUPeel.cpp", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index 692db84dbc5b3..d03d1bdb84c49 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ 
b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -59,6 +59,7 @@ iree_cc_library( "LLVMCPUEmitVectorizationRemarks.cpp" "LLVMCPULinkExecutables.cpp" "LLVMCPULowerExecutableTarget.cpp" + "LLVMCPULowerToAccelUKernels.cpp" "LLVMCPULowerToUKernels.cpp" "LLVMCPUMmt4dVectorLowering.cpp" "LLVMCPUPeel.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp new file mode 100644 index 0000000000000..b8f6b5022702e --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp @@ -0,0 +1,205 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/builtins/ukernel/exported_bits.h" +#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h" +#include "iree/compiler/Codegen/Dialect/IREECodegenOps.h" +#include "iree/compiler/Codegen/Dialect/UKernelOps.h" +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +class LLVMCPULowerToAccelUKernelsPass + : public LLVMCPULowerToAccelUKernelsBase { +public: + LLVMCPULowerToAccelUKernelsPass() = default; + + void 
getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override; + + LogicalResult initializeOptions(StringRef options) override { + if (failed(Pass::initializeOptions(options))) { + return failure(); + } + return success(); + } +}; + +/// Holds a function name and attributes. +struct FnNameAndDefAttrs { + std::string name; + SmallVector defAttrs; +}; + +/// Returns the function name and attributes to use for a ukernel with given +/// `ukernelName` on the target described by `targetAttr`. +static FnNameAndDefAttrs +getFnNameAndDefAttrs(const char *ukernelName, RewriterBase &rewriter, + IREE::HAL::ExecutableTargetAttr targetAttr) { + FnNameAndDefAttrs result; + result.name = ukernelName; + result.defAttrs.emplace_back( + rewriter.getStringAttr("hal.import.fields"), + rewriter.getArrayAttr({rewriter.getStringAttr("processor_data"), + rewriter.getStringAttr("processor_id")})); + result.defAttrs.emplace_back( + rewriter.getStringAttr("hal.import.cconv"), + IREE::HAL::CallingConventionAttr::get( + rewriter.getContext(), + IREE::HAL::CallingConvention::ParameterStruct)); + return result; +} + +/// Matches an (linalg.fill -> )? linalg.mmt4d operation sequence and converts +/// it into a iree_codegen.ukernel.generic "accel_matmul_f32" operation, that is +/// later lowered into a call to the microkernel. 
+static FailureOr +matchDAGForUKernel(RewriterBase &rewriter, linalg::Mmt4DOp op) { + Value lhs = op.getDpsInputOperand(0)->get(); + Value rhs = op.getDpsInputOperand(1)->get(); + Value out = op.getDpsInitOperand(0)->get(); + //auto lhsType = llvm::cast(lhs.getType()); + //auto rhsType = llvm::cast(rhs.getType()); + auto outType = llvm::cast(out.getType()); + /* + Type lhsElemType = lhsType.getElementType(); + Type rhsElemType = rhsType.getElementType(); + Type outElemType = outType.getElementType(); + uint32_t flags = 0; + if (lhsElemType.isSignlessInteger(8) && rhsElemType.isSignlessInteger(8) && + outElemType.isSignlessInteger(32)) { + flags = IREE_UK_FLAG_MMT4D_TYPE_I8I8I32; + } else if (lhsElemType.isF32() && rhsElemType.isF32() && + outElemType.isF32()) { + flags = IREE_UK_FLAG_MMT4D_TYPE_F32F32F32; + } else { + return rewriter.notifyMatchFailure( + op, "unsupported combination of element types"); + } + */ + Location loc = op.getLoc(); + + if (outType.getShape()[0] != 1 || outType.getShape()[1] != 1) { + return rewriter.notifyMatchFailure(op, "outer dims need to be 1"); + } + + auto outTypeRanked = out.getType().cast(); + RankedTensorType intermediateOutType = + RankedTensorType::Builder(outTypeRanked).dropDim(0); + RankedTensorType reducedOutType = + RankedTensorType::Builder(intermediateOutType).dropDim(0); + Value reducedOut = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, out, reducedOutType); + + auto lhsTypeRanked = lhs.getType().cast(); + RankedTensorType intermediateLhsType = + RankedTensorType::Builder(lhsTypeRanked).dropDim(0); + RankedTensorType reducedLhsType = + RankedTensorType::Builder(intermediateLhsType).dropDim(0); + auto reducedLhs = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, lhs, reducedLhsType); + + auto rhsTypeRanked = rhs.getType().cast(); + RankedTensorType intermediateRhsType = + RankedTensorType::Builder(rhsTypeRanked).dropDim(0); + RankedTensorType reducedRhsType = + 
RankedTensorType::Builder(intermediateRhsType).dropDim(0); + auto reducedRhs = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, rhs, reducedRhsType); + /* + auto getDimAsI32 = [](RewriterBase &rewriter, Location loc, Value value, + int dim) -> Value { + return rewriter.create( + loc, rewriter.getI32Type(), + rewriter.create(loc, value, dim)); + }; + Value m = getDimAsI32(rewriter, loc, reducedLhs, 0); + Value n = getDimAsI32(rewriter, loc, reducedRhs, 0); + Value k = getDimAsI32(rewriter, loc, reducedRhs, 1); + */ + Value m = rewriter.create(loc, reducedLhs, 0); + Value n = rewriter.create(loc, reducedLhs, 1); + Value k = rewriter.create(loc, reducedRhs, 0); + //Value flagsVal = rewriter.create( + // loc, rewriter.getI32IntegerAttr(flags)); + auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op); + auto fn = getFnNameAndDefAttrs("accel_matmul_f32", rewriter, targetAttr); + auto genericMicroKernelOp = rewriter.create( + loc, reducedOutType, fn.name, ValueRange{reducedLhs, reducedRhs}, + reducedOut, ValueRange{m, n, k}, + /*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs), + /*strided_outer_dims=*/rewriter.getIndexAttr(0)); + auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp( + rewriter, loc, genericMicroKernelOp.getResult(0), out); + op.getResults()[0].replaceAllUsesWith(insertSliceOp); + return cast( + genericMicroKernelOp.getOperation()); +} + +template +struct LowerToAccelUKernelPattern : OpRewritePattern { + LowerToAccelUKernelPattern(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(OpType op, + PatternRewriter &rewriter) const override { + FailureOr ukernelOp = + matchDAGForUKernel(rewriter, op); + if (failed(ukernelOp)) { + return rewriter.notifyMatchFailure( + op, "failed to find microkernel op to replace with"); + } + rewriter.replaceOp(op, ukernelOp.value()->getResults()); + return success(); + } +}; + +void LLVMCPULowerToAccelUKernelsPass::runOnOperation() { + 
MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + // Enabling a lowering of an op to a microkernel is a trade-off between the + // potential performance advantage of a microkernel over pure code generation + // for that op, and the potential benefits of fusions. Indeed, once an op + // lowered into a microkernel, it will never be fused at any MLIR level. + // Since microkernels are linked as bitcode, they will still undergo LTO-like + // optimization in their calling contexts, but we shouldn't expect this to + // achieve similar results as fusing structured ops. + patterns.insert>(context); + mlir::memref::populateResolveShapedTypeResultDimsPatterns(patterns); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> createLLVMCPULowerToAccelUKernelsPass() { + return std::make_unique(); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index d05fc847fd3f4..e24b2b4d88d39 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -55,6 +55,11 @@ static llvm::cl::opt clEnablePadConsumerFusion( llvm::cl::desc("Flag to enable the fusion for pad + consumer"), llvm::cl::init(false)); +static llvm::cl::opt clEnableAccelMicrokernels( + "iree-llvmcpu-enable-accel-ukernels", + llvm::cl::desc("Flag to enable lowering to accel microkernels"), + llvm::cl::init(false)); + static llvm::cl::opt clEnableMicrokernelsDecomposeLinalgGeneric( "iree-vmvx-enable-microkernels-decompose-linalg-generic", llvm::cl::desc("Enables decomposition of linalg.generic ops when " @@ -599,7 +604,21 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager, OpPassManager &nestedModulePM = passManager.nest(); - if (enableMicrokernels) { + if (clEnableAccelMicrokernels) 
{ + nestedModulePM.addNestedPass(createLLVMCPUTileAndFusePass( + static_cast(tilingConfig.getVectorCommonParallelLevel()))); + nestedModulePM.addNestedPass(createLLVMCPUTilePass( + static_cast(tilingConfig.getVectorReductionLevel()))); + nestedModulePM.addNestedPass( + createDecomposeBatchMmt4DOpsPass()); + nestedModulePM.addPass(createLLVMCPULowerToAccelUKernelsPass()); + nestedModulePM.addNestedPass( + createConvertToDestinationPassingStylePass()); + nestedModulePM.addNestedPass( + createGenericVectorizationPass()); + nestedModulePM.addNestedPass( + createHoistRedundantVectorTransfersPass()); + } else if (enableMicrokernels) { nestedModulePM.addNestedPass( createDecomposeBatchMmt4DOpsPass()); nestedModulePM.addPass( @@ -620,7 +639,7 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager, addBufferizePasses(nestedModulePM); - if (!enableMicrokernels) { + if (!enableMicrokernels && !clEnableAccelMicrokernels) { nestedModulePM.addNestedPass( createLLVMCPUMmt4dVectorLoweringPass()); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h index 47dad29749e12..7f536ac1f4e47 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h @@ -48,6 +48,10 @@ std::unique_ptr createExpandF16OpToF32Pass(); std::unique_ptr> createLLVMCPULowerToUKernelsPass(bool skipIntermediateRoundings = true); +/// Pass to lower a sequence of operations to a iree_codegen.ukernel.* +/// operation. 
+std::unique_ptr> createLLVMCPULowerToAccelUKernelsPass(); + std::unique_ptr> createLLVMCPUMmt4dVectorLoweringPass(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td index 69fc43ffc07dd..da886893d5bf3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td @@ -86,6 +86,14 @@ def LLVMCPULowerToUKernels : ]; } +def LLVMCPULowerToAccelUKernels : + Pass<"iree-llvmcpu-lower-to-accel-ukernels", ""> { + let summary = + "Separate out parts of the IR that lower to an accel-micro-kernel"; + let constructor = + "mlir::iree_compiler::createLLVMCPULowerToAccelUKernelsPass()"; +} + def LLVMCPUMmt4dVectorLowering : Pass<"iree-llvmcpu-mmt4d-vector-lowering", "func::FuncOp"> { let summary = "Apply vector lowering logic to vector ops"; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel index 398c6ac28c26b..a970a9eb7eea9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel @@ -35,6 +35,7 @@ iree_lit_test_suite( "hal_interface_constants.mlir", "hal_interface_workgroup_info.mlir", "illegal_configuration.mlir", + "lower_to_accel_ukernel_ops.mlir", "lower_to_ukernel_ops.mlir", "materialize_aarch64_launch_configuration.mlir", "materialize_configuration_without_distribution.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt index 0c2f1d9002cdd..217d6afd69c92 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt @@ -30,6 +30,7 @@ iree_lit_test_suite( "hal_interface_constants.mlir" "hal_interface_workgroup_info.mlir" "illegal_configuration.mlir" + "lower_to_accel_ukernel_ops.mlir" 
"lower_to_ukernel_ops.mlir" "materialize_aarch64_launch_configuration.mlir" "materialize_configuration_without_distribution.mlir" diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_accel_ukernel_ops.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_accel_ukernel_ops.mlir new file mode 100644 index 0000000000000..a8b7ee5c79f64 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_accel_ukernel_ops.mlir @@ -0,0 +1,39 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-lower-to-accel-ukernels,cse,canonicalize))" %s | FileCheck %s + +func.func @mmt4d_f32f32f32(%arg0 : tensor<1x1x?x?xf32>, %arg1 : tensor<1x1x?x?xf32>, + %arg2 : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32> { + %0 = linalg.mmt4d ins(%arg0, %arg1 : tensor<1x1x?x?xf32>, tensor<1x1x?x?xf32>) + outs(%arg2 : tensor<1x1x?x?xf32>) -> tensor<1x1x?x?xf32> + return %0 : tensor<1x1x?x?xf32> +} + +// CHECK: func @mmt4d_f32f32f32( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x1x?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x1x?x?xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<1x1x?x?xf32> +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 +// CHECK-DAG: %[[C1_i32:.+]] = arith.constant 1 +// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG2]], %[[C2]] +// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG2]], %[[C3]] +// CHECK-DAG: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG2]] +// CHECK-DAG: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] +// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[ARG0]], %[[C3]] +// CHECK-DAG: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-DAG: %[[DIM_4:.+]] = tensor.dim %[[ARG1]], %[[C2]] +// CHECK-DAG: %[[DIM_5:.+]] = tensor.dim %[[ARG1]], %[[C3]] +// CHECK-DAG: %[[EXTRACTED_SLICE_6:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-DAG: %[[M:.+]] = 
tensor.dim %[[EXTRACTED_SLICE_3]], %[[C0]]: tensor +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[EXTRACTED_SLICE_6]], %[[C0]]: tensor +// CHECK-DAG: %[[K:.+]] = tensor.dim %[[EXTRACTED_SLICE_6]], %[[C1]]: tensor +// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "accel_matmul_f32" +// CHECK-SAME: ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_6]] : +// CHECK-SAME: outs(%[[EXTRACTED_SLICE]] : +// CHECK-SAME: (%[[M]], %[[N]], %[[K]] : +// CHECK-DAG: "processor_data" +// CHECK-DAG: "processor_id" +// CHECK-DAG: strided_outer_dims(0) +// CHECK-DAG: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[MICRO_KERNEL]] into %[[ARG2]]: tensor +// CHECK: return %[[INSERTED_SLICE]] From e08fba71fc080b931f8ccb364141fb01289d0264 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 28 Sep 2023 08:24:48 -0700 Subject: [PATCH 44/44] Take slice of FillOp and use transposed fn name. --- .../LLVMCPU/LLVMCPULowerToAccelUKernels.cpp | 40 ++++++++++++------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp index b8f6b5022702e..ff7314bb2b70e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -74,7 +75,7 @@ getFnNameAndDefAttrs(const char *ukernelName, RewriterBase &rewriter, } /// Matches an (linalg.fill -> )? 
linalg.mmt4d operation sequence and converts -/// it into a iree_codegen.ukernel.generic "accel_matmul_f32" operation, that is +/// it into a iree_codegen.ukernel.generic "accel_matmul_t_f32" operation, that is /// later lowered into a call to the microkernel. static FailureOr matchDAGForUKernel(RewriterBase &rewriter, linalg::Mmt4DOp op) { @@ -105,14 +106,27 @@ matchDAGForUKernel(RewriterBase &rewriter, linalg::Mmt4DOp op) { if (outType.getShape()[0] != 1 || outType.getShape()[1] != 1) { return rewriter.notifyMatchFailure(op, "outer dims need to be 1"); } - auto outTypeRanked = out.getType().cast(); RankedTensorType intermediateOutType = RankedTensorType::Builder(outTypeRanked).dropDim(0); RankedTensorType reducedOutType = RankedTensorType::Builder(intermediateOutType).dropDim(0); - Value reducedOut = tensor::createCanonicalRankReducingExtractSliceOp( - rewriter, loc, out, reducedOutType); + Value reducedOut; + Value initTensor; + if (auto oldFillOp = out.getDefiningOp()) { + initTensor = oldFillOp.output(); + auto newInit = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, initTensor, reducedOutType); + reducedOut = + rewriter + .create(loc, ValueRange{oldFillOp.value()}, + ValueRange{newInit}) + .result(); + } else { + reducedOut = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, out, reducedOutType); + initTensor = out; + } auto lhsTypeRanked = lhs.getType().cast(); RankedTensorType intermediateLhsType = @@ -146,14 +160,14 @@ matchDAGForUKernel(RewriterBase &rewriter, linalg::Mmt4DOp op) { //Value flagsVal = rewriter.create( // loc, rewriter.getI32IntegerAttr(flags)); auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op); - auto fn = getFnNameAndDefAttrs("accel_matmul_f32", rewriter, targetAttr); + auto fn = getFnNameAndDefAttrs("accel_matmul_t_f32", rewriter, targetAttr); auto genericMicroKernelOp = rewriter.create( loc, reducedOutType, fn.name, ValueRange{reducedLhs, reducedRhs}, reducedOut, ValueRange{m, n, k}, 
/*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs), /*strided_outer_dims=*/rewriter.getIndexAttr(0)); auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp( - rewriter, loc, genericMicroKernelOp.getResult(0), out); + rewriter, loc, genericMicroKernelOp.getResult(0), initTensor); op.getResults()[0].replaceAllUsesWith(insertSliceOp); return cast( genericMicroKernelOp.getOperation()); @@ -179,16 +193,14 @@ struct LowerToAccelUKernelPattern : OpRewritePattern { void LLVMCPULowerToAccelUKernelsPass::runOnOperation() { MLIRContext *context = &getContext(); + // Convert mmt4d ops to iree_codegen.ukernel.generic "accel_matmul_t_f32" ops. RewritePatternSet patterns(context); - // Enabling a lowering of an op to a microkernel is a trade-off between the - // potential performance advantage of a microkernel over pure code generation - // for that op, and the potential benefits of fusions. Indeed, once an op - // lowered into a microkernel, it will never be fused at any MLIR level. - // Since microkernels are linked as bitcode, they will still undergo LTO-like - // optimization in their calling contexts, but we shouldn't expect this to - // achieve similar results as fusing structured ops. patterns.insert>(context); - mlir::memref::populateResolveShapedTypeResultDimsPatterns(patterns); + // Canonicalize extract and insert slice ops created during the conversion. + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, context); + tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context); + //mlir::memref::populateResolveShapedTypeResultDimsPatterns(patterns); if (failed( applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure();