From eb87cf7e0143b3ca39498b9871c16a41cc278cf0 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 28 Nov 2024 05:58:41 +0000 Subject: [PATCH 01/10] [Codegen][Common] Add a pass to linearize memrefs -- This commit creates a pass to linearize memrefs. -- The pass `iree-linearize-memrefs` will be iteratively worked upon to make it an inter-procedural pass. -- Currently it supports limited operations. Signed-off-by: Abhishek Varma --- .../iree/compiler/Codegen/Common/BUILD.bazel | 1 + .../compiler/Codegen/Common/CMakeLists.txt | 1 + .../Codegen/Common/LinearizeMemRefs.cpp | 343 ++++++++++++++++++ .../iree/compiler/Codegen/Common/Passes.td | 15 + .../compiler/Codegen/Common/test/BUILD.bazel | 1 + .../Codegen/Common/test/CMakeLists.txt | 1 + .../Common/test/linearize_memrefs.mlir | 65 ++++ 7 files changed, 427 insertions(+) create mode 100644 compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 8f7cb459a6d5..11c7f101b1eb 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -123,6 +123,7 @@ iree_compiler_cc_library( "IREEExpandStridedMetadata.cpp", "IREELoopInvariantCodeMotion.cpp", "InstrumentMemoryAccesses.cpp", + "LinearizeMemRefs.cpp", "LinkTuningSpecsPass.cpp", "LowerExecutableUsingTransformDialect.cpp", "LowerUKernelsToCalls.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 1b67c5db261e..302055e8b934 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -115,6 +115,7 @@ iree_cc_library( "IREEExpandStridedMetadata.cpp" "IREELoopInvariantCodeMotion.cpp" "InstrumentMemoryAccesses.cpp" + "LinearizeMemRefs.cpp" "LinkTuningSpecsPass.cpp" "LowerExecutableUsingTransformDialect.cpp" "LowerUKernelsToCalls.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp new file mode 100644 index 000000000000..399369247102 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp @@ -0,0 +1,343 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===- LinearizeMemRefs.cpp - Flatten n-D MemRef subspan ------------------===// +// +// This file implements an interprocedural pass to linearize memrefs. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-linearize-memrefs" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_LINEARIZEMEMREFS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +namespace { + +static SmallVector getLinearizedShape(MemRefType ty, int srcBits, + int dstBits) { + if (ty.getRank() == 0) + return {}; + + int64_t linearizedShape = 1; + for (auto shape : ty.getShape()) { + if (shape == ShapedType::kDynamic) + return {ShapedType::kDynamic}; + linearizedShape *= shape; + } + int scale = dstBits / srcBits; + // Scale the size to the ceilDiv(linearizedShape, scale) + // to accomodate all the values. + linearizedShape = (linearizedShape + scale - 1) / scale; + return {linearizedShape}; +} + +static LogicalResult linearizeType(MemRefType memrefType, + MemRefType &newMemrefType) { + // Fetch linearized shape. + // TODO(avarma): Take into account different src/dst bits. + int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + SmallVector linearizedShape = + getLinearizedShape(memrefType, srcBits, srcBits); + // Fetch offset and strides of the old memref. + SmallVector strides; + int64_t offset; + if (failed(getStridesAndOffset(memrefType, strides, offset))) + return failure(); + if (!strides.empty() && strides.back() != 1) + return failure(); + // Form layout for the linearized memref. + StridedLayoutAttr layoutAttr; + // If the offset is 0, we do not need a strided layout as the stride is + // 1, so we only use the strided layout if the offset is not 0. + if (offset != 0) { + layoutAttr = StridedLayoutAttr::get(memrefType.getContext(), offset, + ArrayRef{1}); + } + Type elementType = memrefType.getElementType(); + newMemrefType = MemRefType::get(linearizedShape, elementType, layoutAttr, + memrefType.getMemorySpace()); + return success(); +} + +static LogicalResult +getLinearizedTypeFromSourceType(MemRefType currentTypeOfSourceMemref, + MemRefType &linearizedType) { + if (!currentTypeOfSourceMemref) + return failure(); + if (currentTypeOfSourceMemref.getRank() < 2) + return success(); + // Convert current type later. + return linearizeType(currentTypeOfSourceMemref, linearizedType); +} + +template +struct LinearizeMemrefAlloc : public OpRewritePattern { + LinearizeMemrefAlloc(MLIRContext *context, PatternBenefit benefit = 10) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(OpTy allocOp, + PatternRewriter &rewriter) const override { + static_assert(std::is_same() || + std::is_same(), + "expected only memref::AllocOp or memref::AllocaOp"); + Location loc = allocOp->getLoc(); + MemRefType currentTypeOfSourceMemref = + dyn_cast(allocOp.getMemref().getType()); + MemRefType newTypeOfSourceMemref; + if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, + newTypeOfSourceMemref))) { + return failure(); + } + if (currentTypeOfSourceMemref.getRank() < 2) + return success(); + + auto elementType = currentTypeOfSourceMemref.getElementType(); + int srcBits = elementType.getIntOrFloatBitWidth(); + + OpFoldResult zero = rewriter.getIndexAttr(0); + + // Get linearized type. + int dstBits = srcBits; + SmallVector sizes = allocOp.getMixedSizes(); + + memref::LinearizedMemRefInfo linearizedMemRefInfo = + memref::getLinearizedMemRefOffsetAndSize( + rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes); + SmallVector dynamicLinearizedSize; + if (!newTypeOfSourceMemref.hasStaticShape()) { + dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp( + rewriter, loc, linearizedMemRefInfo.linearizedSize)); + } + + rewriter.replaceOpWithNewOp( + allocOp, newTypeOfSourceMemref, dynamicLinearizedSize, + allocOp.getSymbolOperands(), allocOp.getAlignmentAttr()); + return success(); + } +}; + +static Value linearizeOperand(Location loc, PatternRewriter &rewriter, + Value operand, MemRefType linearizedType) { + return rewriter.create( + loc, linearizedType, operand, 0, linearizedType.getShape(), + ArrayRef({1})); +} + +struct LinearizeMemrefLoad : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::LoadOp loadOp, + PatternRewriter &rewriter) const override { + Location loc = loadOp->getLoc(); + MemRefType currentTypeOfSourceMemref = loadOp.getMemRefType(); + MemRefType newTypeOfSourceMemref; + if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, + newTypeOfSourceMemref))) { + return failure(); + } + if (currentTypeOfSourceMemref.getRank() < 2 && + loadOp.getIndices().size() < 2) + return success(); + + Value linearizedIndices = rewriter.create( + loc, loadOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); + Value linearizedOperand = linearizeOperand( + loc, rewriter, loadOp.getMemref(), newTypeOfSourceMemref); + Value linearizedLoad = rewriter.create( + loc, linearizedOperand, linearizedIndices); + + rewriter.replaceOp(loadOp, {linearizedLoad}); + return success(); + } +}; + +struct LinearizeMemrefStore : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::StoreOp storeOp, + PatternRewriter &rewriter) const override { + Location loc = storeOp->getLoc(); + MemRefType currentTypeOfSourceMemref = storeOp.getMemRefType(); + MemRefType newTypeOfSourceMemref; + if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, + newTypeOfSourceMemref))) { + return failure(); + } + if (currentTypeOfSourceMemref.getRank() < 2 && + storeOp.getIndices().size() < 2) + return success(); + + auto elementType = storeOp.getMemRefType().getElementType(); + int srcBits = elementType.getIntOrFloatBitWidth(); + Value linearizedIndices = rewriter.create( + loc, storeOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); + Value linearizedOperand = linearizeOperand( + loc, rewriter, storeOp.getMemref(), newTypeOfSourceMemref); + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getValueToStore(), linearizedOperand, + linearizedIndices, srcBits); + + return success(); + } +}; + +struct LinearizeMemrefDealloc : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::DeallocOp deallocOp, + PatternRewriter &rewriter) const override { + Location loc = deallocOp->getLoc(); + MemRefType currentTypeOfSourceMemref = + dyn_cast(deallocOp.getMemref().getType()); + MemRefType newTypeOfSourceMemref; + if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, + newTypeOfSourceMemref))) { + return failure(); + } + if (currentTypeOfSourceMemref.getRank() < 2) + return success(); + + Value linearizedOperand = linearizeOperand( + loc, rewriter, deallocOp.getMemref(), newTypeOfSourceMemref); + + rewriter.replaceOpWithNewOp(deallocOp, + linearizedOperand); + return success(); + } +}; + +struct LinearizeMemrefCopy : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CopyOp copyOp, + PatternRewriter &rewriter) const override { + Location loc = copyOp->getLoc(); + MemRefType currentTypeOfSourceMemref = + dyn_cast(copyOp.getSource().getType()); + MemRefType currentTypeOfTargetMemref = + dyn_cast(copyOp.getTarget().getType()); + MemRefType newTypeOfSourceMemref; + if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, + newTypeOfSourceMemref))) { + return failure(); + } + if (currentTypeOfSourceMemref.getRank() < 2 && + currentTypeOfTargetMemref.getRank() < 2) + return success(); + + Value linearizedSource = linearizeOperand(loc, rewriter, copyOp.getSource(), + newTypeOfSourceMemref); + Value linearizedTarget = linearizeOperand(loc, rewriter, copyOp.getTarget(), + newTypeOfSourceMemref); + + rewriter.replaceOpWithNewOp(copyOp, linearizedSource, + linearizedTarget); + return success(); + } +}; + +struct LinearizeVectorLoad : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::LoadOp loadOp, + PatternRewriter &rewriter) const override { + Location loc = loadOp->getLoc(); + MemRefType currentTypeOfSourceMemref = loadOp.getMemRefType(); + MemRefType newTypeOfSourceMemref; + if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, + newTypeOfSourceMemref))) { + return failure(); + } + if (currentTypeOfSourceMemref.getRank() < 2 && + loadOp.getIndices().size() < 2) + return success(); + + Value linearizedIndices = rewriter.create( + loc, loadOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); + Value linearizedOperand = linearizeOperand(loc, rewriter, loadOp.getBase(), + newTypeOfSourceMemref); + Value linearizedLoad = rewriter.create( + loc, loadOp.getType(), linearizedOperand, linearizedIndices); + + rewriter.replaceOp(loadOp, {linearizedLoad}); + return success(); + } +}; + +struct LinearizeVectorStore : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::StoreOp storeOp, + PatternRewriter &rewriter) const override { + Location loc = storeOp->getLoc(); + MemRefType currentTypeOfSourceMemref = storeOp.getMemRefType(); + MemRefType newTypeOfSourceMemref; + if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, + newTypeOfSourceMemref))) { + return failure(); + } + if (currentTypeOfSourceMemref.getRank() < 2 && + storeOp.getIndices().size() < 2) + return success(); + + Value linearizedIndices = rewriter.create( + loc, storeOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); + Value linearizedOperand = linearizeOperand(loc, rewriter, storeOp.getBase(), + newTypeOfSourceMemref); + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getValueToStore(), linearizedOperand, + linearizedIndices); + + return success(); + } +}; +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +struct LinearizeMemRefs final : impl::LinearizeMemRefsBase { + void runOnOperation() override; +}; + +void LinearizeMemRefs::runOnOperation() { + LLVM_DEBUG(llvm::dbgs() << "Linearizing Memrefs...\n"); + ModuleOp moduleOp = getOperation(); + MLIRContext *context = &getContext(); + IRRewriter rewriter(context); + + RewritePatternSet patterns(context); + patterns.add>(context); + patterns.add>(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + + (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); + + return; +} +} // namespace +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 811fa9ccc588..8d9384dabb91 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -411,6 +411,21 @@ def InstrumentMemoryAccessesPass : let summary = "Instruments memory reads and writes for address tracking when dispatch instrumentation is enabled."; } +def LinearizeMemRefs : Pass<"iree-linearize-memrefs", "ModuleOp"> { + let summary = + "An inter-procedural pass to linearize memrefs"; + let description = [{ + An inter-procedural pass to linearize memrefs. + Currently operates on :- + 1. memref.load/store + 2. vector.load/store + 3. memref.alloc* + 4. memref.dealloc + 5. memref.copy + }]; + let dependentDialects = ["affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"]; +} + def LinkTuningSpecsPass : Pass<"iree-codegen-link-tuning-specs", "ModuleOp"> { let summary = "Link nested transform dialect tuning specs named sequences into a single entry point"; diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index 10eba99c02ed..1e4aeaca649c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -57,6 +57,7 @@ iree_lit_test_suite( "iree_comprehensive_bufferize.mlir", "iree_expand_strided_metadata.mlir", "iree_loop_invariant_code_motion.mlir", + "linearize_memrefs.mlir", "link_tuning_specs.mlir", "llvmcpu_materialize_encoding.mlir", "lower_ukernel_to_calls.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index 69c26592ff3c..05967a929321 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -53,6 +53,7 @@ iree_lit_test_suite( "iree_comprehensive_bufferize.mlir" "iree_expand_strided_metadata.mlir" "iree_loop_invariant_code_motion.mlir" + "linearize_memrefs.mlir" "link_tuning_specs.mlir" "llvmcpu_materialize_encoding.mlir" "lower_ukernel_to_calls.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir b/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir new file mode 100644 index 000000000000..05484ab8d893 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir @@ -0,0 +1,65 @@ +// RUN: iree-opt -iree-linearize-memrefs -allow-unregistered-dialect %s | FileCheck %s + +// CHECK-LABEL: @vector_load_store( +// CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xi32>) +// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index +// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[ARG0]] to +// CHECK-SAME: offset: [0], sizes: [24], strides: [1] : +// CHECK-SAME: memref<2x3x4xi32> to memref<24xi32> +// CHECK: %[[LOAD:.*]] = vector.load %[[CAST]][%[[C6]]] +// CHECK: %[[CAST_2:.*]] = memref.reinterpret_cast %[[ARG0]] to +// CHECK-SAME: offset: [0], sizes: [24], strides: [1] : +// CHECK-SAME: memref<2x3x4xi32> to memref<24xi32> +// CHECK: vector.store %[[LOAD]], %[[CAST_2]][%[[C6]]] +// CHECK: return %[[LOAD]] +func.func @vector_load_store(%arg0: memref<2x3x4xi32>) -> vector<2xi32> { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %1 = vector.load %arg0[%c0, %c1, %c2] : memref<2x3x4xi32>, vector<2xi32> + vector.store %1, %arg0[%c0, %c1, %c2] : memref<2x3x4xi32>, vector<2xi32> + return %1 : vector<2xi32> +} + +// ----- + +// CHECK-LABEL: @memref_load_store_alloc_dealloc( +// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<24xi32> +// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOC]][%[[C6]]] +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<60xi32> +// CHECK: memref.store %[[LOAD]], %[[ALLOCA]][%[[C7]]] {nontemporal = true} : memref<60xi32> +// CHECK: memref.dealloc %[[ALLOC]] +// CHECK: memref.dealloc %[[ALLOCA]] +// CHECK: return %[[LOAD]] +func.func @memref_load_store_alloc_dealloc() -> i32 { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = memref.alloc() : memref<2x3x4xi32> + %1 = memref.load %0[%c0, %c1, %c2] : memref<2x3x4xi32> + %2 = memref.alloca() : memref<3x4x5xi32> + memref.store %1, %2[%c0, %c1, %c2] : memref<3x4x5xi32> + memref.dealloc %0 : memref<2x3x4xi32> + memref.dealloc %2 : memref<3x4x5xi32> + return %1 : i32 +} + +// ----- + +// CHECK-LABEL: @memref_copy( +// CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xi32>, +// CHECK-SAME: %[[ARG1:.*]]: memref<2x3x4xi32>) +// CHECK: %[[CAST_1:.*]] = memref.reinterpret_cast %[[ARG0]] to +// CHECK-SAME: offset: [0], sizes: [24], strides: [1] : +// CHECK-SAME: memref<2x3x4xi32> to memref<24xi32> +// CHECK: %[[CAST_2:.*]] = memref.reinterpret_cast %[[ARG1]] to +// CHECK-SAME: offset: [0], sizes: [24], strides: [1] : +// CHECK-SAME: memref<2x3x4xi32> to memref<24xi32> +// CHECK: memref.copy %[[CAST_1]], %[[CAST_2]] +// CHECK: return +func.func @memref_copy(%arg0: memref<2x3x4xi32>, %arg1: memref<2x3x4xi32>) { + memref.copy %arg0, %arg1 : memref<2x3x4xi32> to memref<2x3x4xi32> + return +} From 144e87c7a31537dae4c48101764ee1abdc127047 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 4 Dec 2024 07:35:03 +0000 Subject: [PATCH 02/10] Review comment v1.0 Signed-off-by: Abhishek Varma --- .../Codegen/Common/LinearizeMemRefs.cpp | 135 ++++++++---------- 1 file changed, 62 insertions(+), 73 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp index 399369247102..5fc30683c34e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp @@ -33,31 +33,22 @@ namespace mlir::iree_compiler { namespace { -static SmallVector getLinearizedShape(MemRefType ty, int srcBits, - int dstBits) { - if (ty.getRank() == 0) +static SmallVector getLinearizedShape(MemRefType type) { + if (type.getRank() == 0) return {}; int64_t linearizedShape = 1; - for (auto shape : ty.getShape()) { + for (auto shape : type.getShape()) { if (shape == ShapedType::kDynamic) return {ShapedType::kDynamic}; linearizedShape *= shape; } - int scale = dstBits / srcBits; - // Scale the size to the ceilDiv(linearizedShape, scale) - // to accomodate all the values. - linearizedShape = (linearizedShape + scale - 1) / scale; return {linearizedShape}; } -static LogicalResult linearizeType(MemRefType memrefType, - MemRefType &newMemrefType) { +static FailureOr linearizeType(MemRefType memrefType) { // Fetch linearized shape. - // TODO(avarma): Take into account different src/dst bits. - int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); - SmallVector linearizedShape = - getLinearizedShape(memrefType, srcBits, srcBits); + SmallVector linearizedShape = getLinearizedShape(memrefType); // Fetch offset and strides of the old memref. SmallVector strides; int64_t offset; @@ -74,20 +65,16 @@ static LogicalResult linearizeType(MemRefType memrefType, ArrayRef{1}); } Type elementType = memrefType.getElementType(); - newMemrefType = MemRefType::get(linearizedShape, elementType, layoutAttr, - memrefType.getMemorySpace()); - return success(); + return MemRefType::get(linearizedShape, elementType, layoutAttr, + memrefType.getMemorySpace()); } -static LogicalResult -getLinearizedTypeFromSourceType(MemRefType currentTypeOfSourceMemref, - MemRefType &linearizedType) { +static FailureOr +getLinearizedTypeFromSourceType(MemRefType currentTypeOfSourceMemref) { if (!currentTypeOfSourceMemref) return failure(); - if (currentTypeOfSourceMemref.getRank() < 2) - return success(); // Convert current type later. - return linearizeType(currentTypeOfSourceMemref, linearizedType); + return linearizeType(currentTypeOfSourceMemref); } template @@ -103,30 +90,32 @@ struct LinearizeMemrefAlloc : public OpRewritePattern { Location loc = allocOp->getLoc(); MemRefType currentTypeOfSourceMemref = dyn_cast(allocOp.getMemref().getType()); - MemRefType newTypeOfSourceMemref; - if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, - newTypeOfSourceMemref))) { - return failure(); - } if (currentTypeOfSourceMemref.getRank() < 2) return success(); - - auto elementType = currentTypeOfSourceMemref.getElementType(); - int srcBits = elementType.getIntOrFloatBitWidth(); - - OpFoldResult zero = rewriter.getIndexAttr(0); - - // Get linearized type. - int dstBits = srcBits; - SmallVector sizes = allocOp.getMixedSizes(); - - memref::LinearizedMemRefInfo linearizedMemRefInfo = - memref::getLinearizedMemRefOffsetAndSize( - rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes); + FailureOr maybeNewTypeOfSourceMemref = + getLinearizedTypeFromSourceType(currentTypeOfSourceMemref); + if (failed(maybeNewTypeOfSourceMemref)) + return failure(); + MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; + + SmallVector sizes = + llvm::map_to_vector(allocOp.getMixedSizes(), [&](OpFoldResult size) { + Value val; + if (dyn_cast(size)) { + val = rewriter.create( + loc, *getConstantIntValue(size)); + } else { + val = cast(size); + } + return val; + }); + + Value linearizedSize = rewriter.create( + loc, sizes, currentTypeOfSourceMemref.getShape(), true); SmallVector dynamicLinearizedSize; if (!newTypeOfSourceMemref.hasStaticShape()) { - dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp( - rewriter, loc, linearizedMemRefInfo.linearizedSize)); + dynamicLinearizedSize.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, linearizedSize)); } rewriter.replaceOpWithNewOp( @@ -150,14 +139,14 @@ struct LinearizeMemrefLoad : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = loadOp->getLoc(); MemRefType currentTypeOfSourceMemref = loadOp.getMemRefType(); - MemRefType newTypeOfSourceMemref; - if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, - newTypeOfSourceMemref))) { - return failure(); - } if (currentTypeOfSourceMemref.getRank() < 2 && loadOp.getIndices().size() < 2) return success(); + FailureOr maybeNewTypeOfSourceMemref = + getLinearizedTypeFromSourceType(currentTypeOfSourceMemref); + if (failed(maybeNewTypeOfSourceMemref)) + return failure(); + MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; Value linearizedIndices = rewriter.create( loc, loadOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); @@ -178,14 +167,14 @@ struct LinearizeMemrefStore : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = storeOp->getLoc(); MemRefType currentTypeOfSourceMemref = storeOp.getMemRefType(); - MemRefType newTypeOfSourceMemref; - if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, - newTypeOfSourceMemref))) { - return failure(); - } if (currentTypeOfSourceMemref.getRank() < 2 && storeOp.getIndices().size() < 2) return success(); + FailureOr maybeNewTypeOfSourceMemref = + getLinearizedTypeFromSourceType(currentTypeOfSourceMemref); + if (failed(maybeNewTypeOfSourceMemref)) + return failure(); + MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; auto elementType = storeOp.getMemRefType().getElementType(); int srcBits = elementType.getIntOrFloatBitWidth(); @@ -209,13 +198,13 @@ struct LinearizeMemrefDealloc : public OpRewritePattern { Location loc = deallocOp->getLoc(); MemRefType currentTypeOfSourceMemref = dyn_cast(deallocOp.getMemref().getType()); - MemRefType newTypeOfSourceMemref; - if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, - newTypeOfSourceMemref))) { - return failure(); - } if (currentTypeOfSourceMemref.getRank() < 2) return success(); + FailureOr maybeNewTypeOfSourceMemref = + getLinearizedTypeFromSourceType(currentTypeOfSourceMemref); + if (failed(maybeNewTypeOfSourceMemref)) + return failure(); + MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; Value linearizedOperand = linearizeOperand( loc, rewriter, deallocOp.getMemref(), newTypeOfSourceMemref); @@ -236,14 +225,14 @@ struct LinearizeMemrefCopy : public OpRewritePattern { dyn_cast(copyOp.getSource().getType()); MemRefType currentTypeOfTargetMemref = dyn_cast(copyOp.getTarget().getType()); - MemRefType newTypeOfSourceMemref; - if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, - newTypeOfSourceMemref))) { - return failure(); - } if (currentTypeOfSourceMemref.getRank() < 2 && currentTypeOfTargetMemref.getRank() < 2) return success(); + FailureOr maybeNewTypeOfSourceMemref = + getLinearizedTypeFromSourceType(currentTypeOfSourceMemref); + if (failed(maybeNewTypeOfSourceMemref)) + return failure(); + MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; Value linearizedSource = linearizeOperand(loc, rewriter, copyOp.getSource(), newTypeOfSourceMemref); @@ -263,14 +252,14 @@ struct LinearizeVectorLoad : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = loadOp->getLoc(); MemRefType currentTypeOfSourceMemref = loadOp.getMemRefType(); - MemRefType newTypeOfSourceMemref; - if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, - newTypeOfSourceMemref))) { - return failure(); - } if (currentTypeOfSourceMemref.getRank() < 2 && loadOp.getIndices().size() < 2) return success(); + FailureOr maybeNewTypeOfSourceMemref = + getLinearizedTypeFromSourceType(currentTypeOfSourceMemref); + if (failed(maybeNewTypeOfSourceMemref)) + return failure(); + MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; Value linearizedIndices = rewriter.create( loc, loadOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); @@ -291,14 +280,14 @@ struct LinearizeVectorStore : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = storeOp->getLoc(); MemRefType currentTypeOfSourceMemref = storeOp.getMemRefType(); - MemRefType newTypeOfSourceMemref; - if (failed(getLinearizedTypeFromSourceType(currentTypeOfSourceMemref, - newTypeOfSourceMemref))) { - return failure(); - } if (currentTypeOfSourceMemref.getRank() < 2 && storeOp.getIndices().size() < 2) return success(); + FailureOr maybeNewTypeOfSourceMemref = + getLinearizedTypeFromSourceType(currentTypeOfSourceMemref); + if (failed(maybeNewTypeOfSourceMemref)) + return failure(); + MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; Value linearizedIndices = rewriter.create( loc, storeOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); From d8b8357d47eb882b145a7389022bd10cdecb804f Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 4 Dec 2024 16:12:28 +0000 Subject: [PATCH 03/10] Review comments v1.1 Signed-off-by: Abhishek Varma --- .../Codegen/Common/LinearizeMemRefs.cpp | 79 +++++++++---------- .../Common/test/linearize_memrefs.mlir | 8 +- 2 files changed, 42 insertions(+), 45 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp index 5fc30683c34e..3c781e9db1ab 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp @@ -77,10 +77,17 @@ getLinearizedTypeFromSourceType(MemRefType currentTypeOfSourceMemref) { return linearizeType(currentTypeOfSourceMemref); } +static Value reshapeOperand(Location loc, PatternRewriter &rewriter, + Value operand, MemRefType linearizedType) { + return rewriter.create( + loc, linearizedType, operand, 0, linearizedType.getShape(), + ArrayRef({1})); +} + template struct LinearizeMemrefAlloc : public OpRewritePattern { - LinearizeMemrefAlloc(MLIRContext *context, PatternBenefit benefit = 10) - : OpRewritePattern(context, benefit) {} + LinearizeMemrefAlloc(MLIRContext *context) + : OpRewritePattern(context) {} LogicalResult matchAndRewrite(OpTy allocOp, PatternRewriter &rewriter) const override { @@ -98,40 +105,30 @@ struct LinearizeMemrefAlloc : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - SmallVector sizes = - llvm::map_to_vector(allocOp.getMixedSizes(), [&](OpFoldResult size) { - Value val; - if (dyn_cast(size)) { - val = rewriter.create( - loc, *getConstantIntValue(size)); - } else { - val = cast(size); - } - return val; - }); - - Value linearizedSize = rewriter.create( - loc, sizes, currentTypeOfSourceMemref.getShape(), true); + int srcBits = + currentTypeOfSourceMemref.getElementType().getIntOrFloatBitWidth(); + int dstBits = srcBits; + OpFoldResult zero = rewriter.getIndexAttr(0); + + memref::LinearizedMemRefInfo linearizedMemRefInfo = + memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, srcBits, + dstBits, /*offset =*/zero, + allocOp.getMixedSizes()); SmallVector dynamicLinearizedSize; if (!newTypeOfSourceMemref.hasStaticShape()) { - dynamicLinearizedSize.push_back( - getValueOrCreateConstantIndexOp(rewriter, loc, linearizedSize)); + dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp( + rewriter, loc, linearizedMemRefInfo.linearizedSize)); } - - rewriter.replaceOpWithNewOp( - allocOp, newTypeOfSourceMemref, dynamicLinearizedSize, + Value linearizedOp = rewriter.create( + loc, newTypeOfSourceMemref, dynamicLinearizedSize, allocOp.getSymbolOperands(), allocOp.getAlignmentAttr()); + Value delinearizedOp = + reshapeOperand(loc, rewriter, linearizedOp, currentTypeOfSourceMemref); + rewriter.replaceOp(allocOp, delinearizedOp); return success(); } }; -static Value linearizeOperand(Location loc, PatternRewriter &rewriter, - Value operand, MemRefType linearizedType) { - return rewriter.create( - loc, linearizedType, operand, 0, linearizedType.getShape(), - ArrayRef({1})); -} - struct LinearizeMemrefLoad : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -150,8 +147,8 @@ struct LinearizeMemrefLoad : public OpRewritePattern { Value linearizedIndices = rewriter.create( loc, loadOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); - Value linearizedOperand = linearizeOperand( - loc, rewriter, loadOp.getMemref(), newTypeOfSourceMemref); + Value linearizedOperand = reshapeOperand(loc, rewriter, loadOp.getMemref(), + newTypeOfSourceMemref); Value linearizedLoad = rewriter.create( loc, linearizedOperand, linearizedIndices); @@ -180,8 +177,8 @@ struct LinearizeMemrefStore : public OpRewritePattern { int srcBits = elementType.getIntOrFloatBitWidth(); Value linearizedIndices = rewriter.create( loc, storeOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); - Value linearizedOperand = linearizeOperand( - loc, rewriter, storeOp.getMemref(), newTypeOfSourceMemref); + Value linearizedOperand = reshapeOperand(loc, rewriter, storeOp.getMemref(), + newTypeOfSourceMemref); rewriter.replaceOpWithNewOp( storeOp, storeOp.getValueToStore(), linearizedOperand, linearizedIndices, srcBits); @@ -206,7 +203,7 @@ struct LinearizeMemrefDealloc : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - Value linearizedOperand = linearizeOperand( + Value linearizedOperand = reshapeOperand( loc, rewriter, deallocOp.getMemref(), newTypeOfSourceMemref); rewriter.replaceOpWithNewOp(deallocOp, @@ -234,10 +231,10 @@ struct LinearizeMemrefCopy : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - Value linearizedSource = linearizeOperand(loc, rewriter, copyOp.getSource(), - newTypeOfSourceMemref); - Value linearizedTarget = linearizeOperand(loc, rewriter, copyOp.getTarget(), - newTypeOfSourceMemref); + Value linearizedSource = reshapeOperand(loc, rewriter, copyOp.getSource(), + newTypeOfSourceMemref); + Value linearizedTarget = reshapeOperand(loc, rewriter, copyOp.getTarget(), + newTypeOfSourceMemref); rewriter.replaceOpWithNewOp(copyOp, linearizedSource, linearizedTarget); @@ -263,8 +260,8 @@ struct LinearizeVectorLoad : public OpRewritePattern { Value linearizedIndices = rewriter.create( loc, loadOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); - Value linearizedOperand = linearizeOperand(loc, rewriter, loadOp.getBase(), - newTypeOfSourceMemref); + Value linearizedOperand = + reshapeOperand(loc, rewriter, loadOp.getBase(), newTypeOfSourceMemref); Value linearizedLoad = rewriter.create( loc, loadOp.getType(), linearizedOperand, linearizedIndices); @@ -291,8 +288,8 @@ struct LinearizeVectorStore : public OpRewritePattern { Value linearizedIndices = rewriter.create( loc, storeOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); - Value linearizedOperand = linearizeOperand(loc, rewriter, storeOp.getBase(), - newTypeOfSourceMemref); + Value linearizedOperand = + reshapeOperand(loc, rewriter, storeOp.getBase(), newTypeOfSourceMemref); rewriter.replaceOpWithNewOp( storeOp, storeOp.getValueToStore(), linearizedOperand, linearizedIndices); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir b/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir index 05484ab8d893..dabdba15322c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir @@ -12,13 +12,13 @@ // CHECK-SAME: memref<2x3x4xi32> to memref<24xi32> // CHECK: vector.store %[[LOAD]], %[[CAST_2]][%[[C6]]] // CHECK: return %[[LOAD]] -func.func @vector_load_store(%arg0: memref<2x3x4xi32>) -> vector<2xi32> { +func.func @vector_load_store(%arg0: memref<2x3x4xi32>) -> vector<2x3xi32> { %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %1 = vector.load %arg0[%c0, %c1, %c2] : memref<2x3x4xi32>, vector<2xi32> - vector.store %1, %arg0[%c0, %c1, %c2] : memref<2x3x4xi32>, vector<2xi32> - return %1 : vector<2xi32> + %1 = vector.load %arg0[%c0, %c1, %c2] : memref<2x3x4xi32>, vector<2x3xi32> + vector.store %1, %arg0[%c0, %c1, %c2] : memref<2x3x4xi32>, vector<2x3xi32> + return %1 : vector<2x3xi32> } // ----- From 26cb7c7b91d0c0722688e44f94d44ee8121cc72f Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 5 Dec 2024 07:49:38 +0000 Subject: [PATCH 04/10] Use expand_shape at the boundary of allocOp Signed-off-by: Abhishek Varma --- .../Codegen/Common/LinearizeMemRefs.cpp | 37 ++++++++++--------- .../Common/test/linearize_memrefs.mlir | 20 ++++++++-- 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp index 3c781e9db1ab..a1eb4348b432 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp @@ -77,8 +77,8 @@ getLinearizedTypeFromSourceType(MemRefType currentTypeOfSourceMemref) { return linearizeType(currentTypeOfSourceMemref); } -static Value reshapeOperand(Location loc, PatternRewriter &rewriter, - Value operand, MemRefType linearizedType) { +static Value linearizeOperand(Location loc, PatternRewriter &rewriter, + Value operand, MemRefType linearizedType) { return rewriter.create( loc, linearizedType, operand, 0, linearizedType.getShape(), ArrayRef({1})); @@ -122,8 +122,11 @@ struct LinearizeMemrefAlloc : public OpRewritePattern { Value linearizedOp = rewriter.create( loc, newTypeOfSourceMemref, dynamicLinearizedSize, allocOp.getSymbolOperands(), allocOp.getAlignmentAttr()); - Value delinearizedOp = - reshapeOperand(loc, rewriter, linearizedOp, currentTypeOfSourceMemref); + SmallVector indices(currentTypeOfSourceMemref.getRank()); + std::iota(indices.begin(), indices.end(), 0); + Value delinearizedOp = rewriter.create( + loc, currentTypeOfSourceMemref, linearizedOp, ArrayRef({indices}), + allocOp.getMixedSizes()); rewriter.replaceOp(allocOp, delinearizedOp); return success(); } @@ -147,8 +150,8 @@ struct LinearizeMemrefLoad : public OpRewritePattern { Value linearizedIndices = rewriter.create( loc, loadOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); - Value linearizedOperand = reshapeOperand(loc, rewriter, loadOp.getMemref(), - newTypeOfSourceMemref); + Value linearizedOperand = linearizeOperand( + loc, rewriter, loadOp.getMemref(), newTypeOfSourceMemref); Value linearizedLoad = rewriter.create( loc, linearizedOperand, linearizedIndices); @@ -177,8 +180,8 @@ struct LinearizeMemrefStore : public OpRewritePattern { int srcBits = elementType.getIntOrFloatBitWidth(); Value linearizedIndices = rewriter.create( loc, storeOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); - Value linearizedOperand = reshapeOperand(loc, rewriter, storeOp.getMemref(), - newTypeOfSourceMemref); + Value linearizedOperand = linearizeOperand( + loc, rewriter, storeOp.getMemref(), newTypeOfSourceMemref); rewriter.replaceOpWithNewOp( storeOp, storeOp.getValueToStore(), linearizedOperand, linearizedIndices, srcBits); @@ -203,7 +206,7 @@ struct LinearizeMemrefDealloc : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - Value linearizedOperand = reshapeOperand( + Value linearizedOperand = linearizeOperand( loc, rewriter, deallocOp.getMemref(), newTypeOfSourceMemref); rewriter.replaceOpWithNewOp(deallocOp, @@ -231,10 +234,10 @@ struct LinearizeMemrefCopy : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - Value linearizedSource = reshapeOperand(loc, rewriter, copyOp.getSource(), - newTypeOfSourceMemref); - Value linearizedTarget = reshapeOperand(loc, rewriter, copyOp.getTarget(), - newTypeOfSourceMemref); + Value linearizedSource = linearizeOperand(loc, rewriter, copyOp.getSource(), + newTypeOfSourceMemref); + Value linearizedTarget = linearizeOperand(loc, rewriter, copyOp.getTarget(), + newTypeOfSourceMemref); rewriter.replaceOpWithNewOp(copyOp, linearizedSource, linearizedTarget); @@ -260,8 +263,8 @@ struct LinearizeVectorLoad : public OpRewritePattern { Value linearizedIndices = rewriter.create( loc, loadOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); - Value linearizedOperand = - reshapeOperand(loc, rewriter, loadOp.getBase(), newTypeOfSourceMemref); + Value linearizedOperand = linearizeOperand(loc, rewriter, loadOp.getBase(), + newTypeOfSourceMemref); Value linearizedLoad = rewriter.create( loc, loadOp.getType(), linearizedOperand, linearizedIndices); @@ -288,8 +291,8 @@ struct LinearizeVectorStore : public OpRewritePattern { Value linearizedIndices = rewriter.create( loc, storeOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); - Value linearizedOperand = - reshapeOperand(loc, rewriter, storeOp.getBase(), newTypeOfSourceMemref); + Value linearizedOperand = linearizeOperand(loc, rewriter, storeOp.getBase(), + newTypeOfSourceMemref); rewriter.replaceOpWithNewOp( storeOp, storeOp.getValueToStore(), linearizedOperand, linearizedIndices); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir b/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir index dabdba15322c..96e9dd025b28 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir @@ -27,11 +27,23 @@ func.func @vector_load_store(%arg0: memref<2x3x4xi32>) -> vector<2x3xi32> { // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index // CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<24xi32> -// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOC]][%[[C6]]] +// CHECK: %[[EXPAND_ALLOC:.*]] = memref.expand_shape %[[ALLOC]] +// CHECK{LITERAL}: [[0, 1, 2]] output_shape [2, 3, 4] +// CHECK-SAME: : memref<24xi32> into memref<2x3x4xi32> +// CHECK: %[[RESHAPE_ALLOC:.*]] = memref.reinterpret_cast %[[EXPAND_ALLOC]] to offset: [0], sizes: [24], strides: [1] +// CHECK-SAME: : memref<2x3x4xi32> to memref<24xi32> +// CHECK: %[[LOAD:.*]] = memref.load %[[RESHAPE_ALLOC]][%[[C6]]] // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<60xi32> -// CHECK: memref.store %[[LOAD]], %[[ALLOCA]][%[[C7]]] {nontemporal = true} : memref<60xi32> -// CHECK: memref.dealloc %[[ALLOC]] -// CHECK: memref.dealloc %[[ALLOCA]] +// CHECK: %[[EXPAND_ALLOCA:.*]] = memref.expand_shape %[[ALLOCA]] +// CHECK{LITERAL}: [[0, 1, 2]] output_shape [3, 4, 5] +// CHECK-SAME: : memref<60xi32> into memref<3x4x5xi32> +// CHECK: %[[RESHAPE_ALLOCA:.*]] = memref.reinterpret_cast %[[EXPAND_ALLOCA]] to offset: [0], sizes: [60], strides: [1] +// CHECK-SAME: : memref<3x4x5xi32> to memref<60xi32> +// CHECK: memref.store %[[LOAD]], %[[RESHAPE_ALLOCA]][%[[C7]]] {nontemporal = true} : memref<60xi32> +// CHECK: %[[RESHAPE_ALLOC_2:.*]] = memref.reinterpret_cast %[[EXPAND_ALLOC]] +// CHECK: memref.dealloc %[[RESHAPE_ALLOC_2]] +// CHECK: %[[RESHAPE_ALLOCA_2:.*]] = memref.reinterpret_cast %[[EXPAND_ALLOCA]] +// CHECK: memref.dealloc %[[RESHAPE_ALLOCA_2]] // CHECK: return %[[LOAD]] func.func @memref_load_store_alloc_dealloc() -> i32 { %c2 = arith.constant 2 : index From 7ede37376f561e805f3cdd308ff845b571d28a49 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 1 Jan 2025 12:36:14 +0000 Subject: [PATCH 05/10] Make improvement by also dealing with dynamic shapes Signed-off-by: Abhishek Varma --- .../Codegen/Common/LinearizeMemRefs.cpp | 197 ++++++++++++++---- .../Common/test/linearize_memrefs.mlir | 143 ++++++++++++- 2 files changed, 295 insertions(+), 45 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp index a1eb4348b432..15fbaf37adca 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp @@ -78,10 +78,66 @@ getLinearizedTypeFromSourceType(MemRefType currentTypeOfSourceMemref) { } static Value linearizeOperand(Location loc, PatternRewriter &rewriter, - Value operand, MemRefType linearizedType) { - return rewriter.create( - loc, linearizedType, operand, 0, linearizedType.getShape(), - ArrayRef({1})); + Value operand, MemRefType linearizedType, + OpFoldResult sizeRes = nullptr) { + // Fetch offset and strides of the old memref. + SmallVector strides; + int64_t offset; + if (failed(getStridesAndOffset(linearizedType, strides, offset))) + // TODO(avarma): Change function signature. + return nullptr; + + if (linearizedType.hasStaticShape()) { + return rewriter.create( + loc, linearizedType, operand, /*offset=*/offset, + linearizedType.getShape(), /*strides=*/strides); + } else { + assert(sizeRes && "expected a linear dynamic size to have been computed"); + SmallVector size = {sizeRes}; + OpFoldResult zeroOffset = rewriter.getIndexAttr(offset); + SmallVector strides; + strides.push_back(rewriter.getIndexAttr(1)); + return rewriter.create( + loc, linearizedType, operand, zeroOffset, size, strides); + } +} + +static SmallVector getDimValues(Location loc, PatternRewriter &rewriter, + MemRefType type, + ValueRange dynamicDims) { + SmallVector dims; + auto shape = type.getShape(); + int dynamicDimIndex = 0; + for (int i = 0; i < shape.size(); ++i) { + if (ShapedType::isDynamic(shape[i])) { + dims.push_back(dynamicDims[dynamicDimIndex++]); + } else { + dims.push_back(rewriter.create(loc, shape[i])); + } + } + return dims; +} + +static FailureOr> +getMixedOrigSize(Location loc, PatternRewriter &rewriter, Value sourceValue) { + MemRefType sourceType = llvm::cast(sourceValue.getType()); + Operation *sourceOp = sourceValue.getDefiningOp(); + if (auto allocOp = dyn_cast_if_present(sourceOp)) { + return getDimValues(loc, rewriter, sourceType, allocOp.getDynamicSizes()); + } else if (auto allocaOp = dyn_cast_if_present(sourceOp)) { + return getDimValues(loc, rewriter, sourceType, allocaOp.getDynamicSizes()); + } else { + if (sourceType.hasStaticShape()) { + SmallVector dims; + dims.reserve(sourceType.getRank()); + for (int64_t dim : sourceType.getShape()) { + dims.push_back(rewriter.create(loc, dim)); + } + return dims; + } else { + return failure(); + } + } } template @@ -105,20 +161,18 @@ struct LinearizeMemrefAlloc : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - int srcBits = - currentTypeOfSourceMemref.getElementType().getIntOrFloatBitWidth(); - int dstBits = srcBits; - OpFoldResult zero = rewriter.getIndexAttr(0); - - memref::LinearizedMemRefInfo linearizedMemRefInfo = - memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, srcBits, - dstBits, /*offset =*/zero, - allocOp.getMixedSizes()); SmallVector dynamicLinearizedSize; if (!newTypeOfSourceMemref.hasStaticShape()) { - dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp( - rewriter, loc, linearizedMemRefInfo.linearizedSize)); + SmallVector basis = getDimValues( + loc, rewriter, currentTypeOfSourceMemref, allocOp.getDynamicSizes()); + SmallVector multiIndices( + basis.size(), rewriter.create(loc, 0)); + multiIndices[0] = basis[0]; + Value linearizedSizes = rewriter.create( + loc, multiIndices, basis, true); + dynamicLinearizedSize.push_back(linearizedSizes); } + Value linearizedOp = rewriter.create( loc, newTypeOfSourceMemref, dynamicLinearizedSize, allocOp.getSymbolOperands(), allocOp.getAlignmentAttr()); @@ -148,12 +202,23 @@ struct LinearizeMemrefLoad : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; + FailureOr> basis = + getMixedOrigSize(loc, rewriter, loadOp.getMemref()); + if (failed(basis)) + return failure(); Value linearizedIndices = rewriter.create( - loc, loadOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); - Value linearizedOperand = linearizeOperand( - loc, rewriter, loadOp.getMemref(), newTypeOfSourceMemref); + loc, loadOp.getIndices(), *basis, true); + + SmallVector multiIndices( + (*basis).size(), rewriter.create(loc, 0)); + multiIndices[0] = (*basis)[0]; + Value linearizedSizes = rewriter.create( + loc, multiIndices, *basis, true); + Value linearizedOperand = + linearizeOperand(loc, rewriter, loadOp.getMemref(), + newTypeOfSourceMemref, linearizedSizes); Value linearizedLoad = rewriter.create( - loc, linearizedOperand, linearizedIndices); + loc, linearizedOperand, linearizedIndices, loadOp.getNontemporal()); rewriter.replaceOp(loadOp, {linearizedLoad}); return success(); @@ -176,15 +241,24 @@ struct LinearizeMemrefStore : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - auto elementType = storeOp.getMemRefType().getElementType(); - int srcBits = elementType.getIntOrFloatBitWidth(); + FailureOr> basis = + getMixedOrigSize(loc, rewriter, storeOp.getMemref()); + if (failed(basis)) + return failure(); Value linearizedIndices = rewriter.create( - loc, storeOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); - Value linearizedOperand = linearizeOperand( - loc, rewriter, storeOp.getMemref(), newTypeOfSourceMemref); + loc, storeOp.getIndices(), *basis, true); + SmallVector multiIndices( + (*basis).size(), rewriter.create(loc, 0)); + multiIndices[0] = (*basis)[0]; + Value linearizedSizes = rewriter.create( + loc, multiIndices, *basis, true); + Value linearizedOperand = + linearizeOperand(loc, rewriter, storeOp.getMemref(), + newTypeOfSourceMemref, linearizedSizes); + rewriter.replaceOpWithNewOp( storeOp, storeOp.getValueToStore(), linearizedOperand, - linearizedIndices, srcBits); + linearizedIndices, storeOp.getNontemporal()); return success(); } @@ -206,8 +280,18 @@ struct LinearizeMemrefDealloc : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - Value linearizedOperand = linearizeOperand( - loc, rewriter, deallocOp.getMemref(), newTypeOfSourceMemref); + FailureOr> basis = + getMixedOrigSize(loc, rewriter, deallocOp.getMemref()); + if (failed(basis)) + return failure(); + SmallVector multiIndices( + (*basis).size(), rewriter.create(loc, 0)); + multiIndices[0] = (*basis)[0]; + Value linearizedSizes = rewriter.create( + loc, multiIndices, *basis, true); + Value linearizedOperand = + linearizeOperand(loc, rewriter, deallocOp.getMemref(), + newTypeOfSourceMemref, linearizedSizes); rewriter.replaceOpWithNewOp(deallocOp, linearizedOperand); @@ -234,10 +318,27 @@ struct LinearizeMemrefCopy : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - Value linearizedSource = linearizeOperand(loc, rewriter, copyOp.getSource(), - newTypeOfSourceMemref); - Value linearizedTarget = linearizeOperand(loc, rewriter, copyOp.getTarget(), - newTypeOfSourceMemref); + FailureOr> basis = + getMixedOrigSize(loc, rewriter, copyOp.getSource()); + if (failed(basis)) + return failure(); + SmallVector multiIndices( + (*basis).size(), rewriter.create(loc, 0)); + multiIndices[0] = (*basis)[0]; + Value linearizedSizes = rewriter.create( + loc, multiIndices, *basis, true); + Value linearizedSource = + linearizeOperand(loc, rewriter, copyOp.getSource(), + newTypeOfSourceMemref, linearizedSizes); + basis = getMixedOrigSize(loc, rewriter, copyOp.getTarget()); + if (failed(basis)) + return failure(); + multiIndices[0] = (*basis)[0]; + linearizedSizes = rewriter.create( + loc, multiIndices, *basis, true); + Value linearizedTarget = + linearizeOperand(loc, rewriter, copyOp.getTarget(), + newTypeOfSourceMemref, linearizedSizes); rewriter.replaceOpWithNewOp(copyOp, linearizedSource, linearizedTarget); @@ -261,10 +362,21 @@ struct LinearizeVectorLoad : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; + FailureOr> basis = + getMixedOrigSize(loc, rewriter, loadOp.getBase()); + if (failed(basis)) + return failure(); Value linearizedIndices = rewriter.create( - loc, loadOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); - Value linearizedOperand = linearizeOperand(loc, rewriter, loadOp.getBase(), - newTypeOfSourceMemref); + loc, loadOp.getIndices(), *basis, true); + SmallVector multiIndices( + (*basis).size(), rewriter.create(loc, 0)); + multiIndices[0] = (*basis)[0]; + Value linearizedSizes = rewriter.create( + loc, multiIndices, *basis, true); + Value linearizedOperand = + linearizeOperand(loc, rewriter, loadOp.getBase(), newTypeOfSourceMemref, + linearizedSizes); + Value linearizedLoad = rewriter.create( loc, loadOp.getType(), linearizedOperand, linearizedIndices); @@ -289,10 +401,21 @@ struct LinearizeVectorStore : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; + FailureOr> basis = + getMixedOrigSize(loc, rewriter, storeOp.getBase()); + if (failed(basis)) + return failure(); Value linearizedIndices = rewriter.create( - loc, storeOp.getIndices(), currentTypeOfSourceMemref.getShape(), true); - Value linearizedOperand = linearizeOperand(loc, rewriter, storeOp.getBase(), - newTypeOfSourceMemref); + loc, storeOp.getIndices(), *basis, true); + SmallVector multiIndices( + (*basis).size(), rewriter.create(loc, 0)); + multiIndices[0] = (*basis)[0]; + Value linearizedSizes = rewriter.create( + loc, multiIndices, *basis, true); + Value linearizedOperand = + linearizeOperand(loc, rewriter, storeOp.getBase(), + newTypeOfSourceMemref, linearizedSizes); + rewriter.replaceOpWithNewOp( storeOp, storeOp.getValueToStore(), linearizedOperand, linearizedIndices); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir b/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir index 96e9dd025b28..b159e33d79e3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/linearize_memrefs.mlir @@ -1,6 +1,10 @@ // RUN: iree-opt -iree-linearize-memrefs -allow-unregistered-dialect %s | FileCheck %s -// CHECK-LABEL: @vector_load_store( +//-------------------------------------------------------------------------- +//---------------------------- VECTOR OPS ---------------------------------- +//-------------------------------------------------------------------------- + +// CHECK-LABEL: @vector_load_store_static( // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xi32>) // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[ARG0]] to @@ -12,7 +16,7 @@ // CHECK-SAME: memref<2x3x4xi32> to memref<24xi32> // CHECK: vector.store %[[LOAD]], %[[CAST_2]][%[[C6]]] // CHECK: return %[[LOAD]] -func.func @vector_load_store(%arg0: memref<2x3x4xi32>) -> vector<2x3xi32> { +func.func @vector_load_store_static(%arg0: memref<2x3x4xi32>) -> vector<2x3xi32> { %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index @@ -21,14 +25,56 @@ func.func @vector_load_store(%arg0: memref<2x3x4xi32>) -> vector<2x3xi32> { return %1 : vector<2x3xi32> } +// CHECK-LABEL: @vector_load_store_dynamic( +// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index, %[[I0:.*]]: index, %[[I1:.*]]: index, %[[I2:.*]]: index) +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[LINEAR_SIZE:.*]] = affine.linearize_index disjoint [%[[DIM0]], %[[C0]], %[[C0]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[ALLOC:.*]] = memref.alloca(%[[LINEAR_SIZE]]) : memref +// CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ALLOC]] +// CHECK{LITERAL}: [[0, 1, 2]] +// CHECK-SAME: output_shape [%[[DIM0]], %[[DIM1]], %[[DIM2]]] +// CHECK-SAME: : memref into memref +// CHECK: %[[LINEAR_INDEX:.*]] = affine.linearize_index disjoint [%[[I0]], %[[I1]], %[[I2]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[LINEAR_SIZE:.*]] = affine.linearize_index disjoint [%[[DIM0]], %[[C0]], %[[C0]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[EXPAND_SHAPE]] to +// CHECK-SAME: offset: [0], sizes: [%[[LINEAR_SIZE]]], strides: [1] : +// CHECK-SAME: memref to memref +// CHECK: %[[LOAD:.*]] = vector.load %[[CAST]][%[[LINEAR_INDEX]]] : memref, vector<2x3xi32> +// CHECK: %[[LINEAR_INDEX:.*]] = affine.linearize_index disjoint [%[[I0]], %[[I1]], %[[I2]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[LINEAR_SIZE:.*]] = affine.linearize_index disjoint [%[[DIM0]], %[[C0]], %[[C0]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[EXPAND_SHAPE]] to +// CHECK-SAME: offset: [0], sizes: [%[[LINEAR_SIZE]]], strides: [1] : +// CHECK-SAME: memref to memref +// CHECK: vector.store %[[LOAD]], %[[CAST]][%[[LINEAR_INDEX]]] : memref, vector<2x3xi32> +// CHECK: return %[[LOAD]] : vector<2x3xi32> +func.func @vector_load_store_dynamic(%dim0 : index, %dim1: index, %dim2: index, %i0: index, %i1: index, %i2: index) -> vector<2x3xi32> { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %alloc = memref.alloca(%dim0, %dim1, %dim2) : memref + %1 = vector.load %alloc[%i0, %i1, %i2] : memref, vector<2x3xi32> + vector.store %1, %alloc[%i0, %i1, %i2] : memref, vector<2x3xi32> + return %1 : vector<2x3xi32> +} + // ----- -// CHECK-LABEL: @memref_load_store_alloc_dealloc( +//-------------------------------------------------------------------------- +//---------------------------- MEMREF OPS ---------------------------------- +//-------------------------------------------------------------------------- + + +// CHECK-LABEL: @memref_load_store_alloc_dealloc_static( // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index // CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<24xi32> // CHECK: %[[EXPAND_ALLOC:.*]] = memref.expand_shape %[[ALLOC]] -// CHECK{LITERAL}: [[0, 1, 2]] output_shape [2, 3, 4] +// CHECK{LITERAL}: [[0, 1, 2]] output_shape [2, 3, 4] // CHECK-SAME: : memref<24xi32> into memref<2x3x4xi32> // CHECK: %[[RESHAPE_ALLOC:.*]] = memref.reinterpret_cast %[[EXPAND_ALLOC]] to offset: [0], sizes: [24], strides: [1] // CHECK-SAME: : memref<2x3x4xi32> to memref<24xi32> @@ -45,22 +91,67 @@ func.func @vector_load_store(%arg0: memref<2x3x4xi32>) -> vector<2x3xi32> { // CHECK: %[[RESHAPE_ALLOCA_2:.*]] = memref.reinterpret_cast %[[EXPAND_ALLOCA]] // CHECK: memref.dealloc %[[RESHAPE_ALLOCA_2]] // CHECK: return %[[LOAD]] -func.func @memref_load_store_alloc_dealloc() -> i32 { +func.func @memref_load_store_alloc_dealloc_static() -> i32 { %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %0 = memref.alloc() : memref<2x3x4xi32> %1 = memref.load %0[%c0, %c1, %c2] : memref<2x3x4xi32> %2 = memref.alloca() : memref<3x4x5xi32> - memref.store %1, %2[%c0, %c1, %c2] : memref<3x4x5xi32> + memref.store %1, %2[%c0, %c1, %c2] {nontemporal = true} : memref<3x4x5xi32> memref.dealloc %0 : memref<2x3x4xi32> memref.dealloc %2 : memref<3x4x5xi32> return %1 : i32 } +// CHECK-LABEL: @memref_load_store_alloc_dealloc_dynamic( +// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index, %[[I0:.*]]: index, %[[I1:.*]]: index, %[[I2:.*]]: index) +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[LINEAR_SIZE:.*]] = affine.linearize_index disjoint [%[[DIM0]], %[[C0]], %[[C0]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[ALLOC:.*]] = memref.alloca(%[[LINEAR_SIZE]]) : memref +// CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ALLOC]] +// CHECK{LITERAL}: [[0, 1, 2]] +// CHECK-SAME: output_shape [%[[DIM0]], %[[DIM1]], %[[DIM2]]] +// CHECK-SAME: : memref into memref +// CHECK: %[[LINEAR_INDEX:.*]] = affine.linearize_index disjoint [%[[I0]], %[[I1]], %[[I2]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[LINEAR_SIZE:.*]] = affine.linearize_index disjoint [%[[DIM0]], %[[C0]], %[[C0]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[EXPAND_SHAPE]] to +// CHECK-SAME: offset: [0], sizes: [%[[LINEAR_SIZE]]], strides: [1] : +// CHECK-SAME: memref to memref +// CHECK: %[[LOAD:.*]] = memref.load %[[CAST]][%[[LINEAR_INDEX]]] : memref +// CHECK: %[[LINEAR_INDEX:.*]] = affine.linearize_index disjoint [%[[C2]], %[[C1]], %[[C0]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[LINEAR_SIZE:.*]] = affine.linearize_index disjoint [%[[DIM0]], %[[C0]], %[[C0]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[EXPAND_SHAPE]] to +// CHECK-SAME: offset: [0], sizes: [%[[LINEAR_SIZE]]], strides: [1] : +// CHECK-SAME: memref to memref +// CHECK: memref.store %[[LOAD]], %[[CAST]][%[[LINEAR_INDEX]]] : memref +// CHECK: %[[LINEAR_SIZE:.*]] = affine.linearize_index disjoint [%[[DIM0]], %[[C0]], %[[C0]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[EXPAND_SHAPE]] to +// CHECK-SAME: offset: [0], sizes: [%[[LINEAR_SIZE]]], strides: [1] : +// CHECK-SAME: memref to memref +// CHECK: memref.dealloc %[[CAST]] : memref +// CHECK: return %[[LOAD]] : f32 +func.func @memref_load_store_alloc_dealloc_dynamic(%dim0 : index, %dim1: index, %dim2: index, %i0: index, %i1: index, %i2: index) -> f32 { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %alloc = memref.alloca(%dim0, %dim1, %dim2) : memref + %1 = memref.load %alloc[%i0, %i1, %i2] : memref + memref.store %1, %alloc[%c2, %c1, %c0] : memref + memref.dealloc %alloc : memref + return %1 : f32 +} // ----- -// CHECK-LABEL: @memref_copy( +// CHECK-LABEL: @memref_copy_static( // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xi32>, // CHECK-SAME: %[[ARG1:.*]]: memref<2x3x4xi32>) // CHECK: %[[CAST_1:.*]] = memref.reinterpret_cast %[[ARG0]] to @@ -71,7 +162,43 @@ func.func @memref_load_store_alloc_dealloc() -> i32 { // CHECK-SAME: memref<2x3x4xi32> to memref<24xi32> // CHECK: memref.copy %[[CAST_1]], %[[CAST_2]] // CHECK: return -func.func @memref_copy(%arg0: memref<2x3x4xi32>, %arg1: memref<2x3x4xi32>) { +func.func @memref_copy_static(%arg0: memref<2x3x4xi32>, %arg1: memref<2x3x4xi32>) { memref.copy %arg0, %arg1 : memref<2x3x4xi32> to memref<2x3x4xi32> return } + +// CHECK-LABEL: @memref_copy_dynamic( +// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index, %[[DIM3:.*]]: index, %[[DIM4:.*]]: index, %[[DIM5:.*]]: index) +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[LINEAR_SIZE:.*]] = affine.linearize_index disjoint [%[[DIM0]], %[[C0]], %[[C0]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[ALLOC:.*]] = memref.alloca(%[[LINEAR_SIZE]]) : memref +// CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ALLOC]] +// CHECK{LITERAL}: [[0, 1, 2]] +// CHECK-SAME: output_shape [%[[DIM0]], %[[DIM1]], %[[DIM2]]] +// CHECK-SAME: : memref into memref +// CHECK: %[[LINEAR_SIZE_1:.*]] = affine.linearize_index disjoint [%[[DIM3]], %[[C0]], %[[C0]] +// CHECK-SAME: by (%[[DIM3]], %[[DIM4]], %[[DIM5]]) : index +// CHECK: %[[ALLOC_1:.*]] = memref.alloca(%[[LINEAR_SIZE_1]]) : memref +// CHECK: %[[EXPAND_SHAPE_1:.*]] = memref.expand_shape %[[ALLOC]] +// CHECK{LITERAL}: [[0, 1, 2]] +// CHECK-SAME: output_shape [%[[DIM3]], %[[DIM4]], %[[DIM5]]] +// CHECK-SAME: : memref into memref +// CHECK: %[[LINEAR_SIZE:.*]] = affine.linearize_index disjoint [%[[DIM0]], %[[C0]], %[[C0]] +// CHECK-SAME: by (%[[DIM0]], %[[DIM1]], %[[DIM2]]) : index +// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[EXPAND_SHAPE]] to +// CHECK-SAME: offset: [0], sizes: [%[[LINEAR_SIZE]]], strides: [1] : +// CHECK-SAME: memref to memref +// CHECK: %[[LINEAR_SIZE_1:.*]] = affine.linearize_index disjoint [%[[DIM3]], %[[C0]], %[[C0]] +// CHECK-SAME: by (%[[DIM3]], %[[DIM4]], %[[DIM5]]) : index +// CHECK: %[[CAST_1:.*]] = memref.reinterpret_cast %[[EXPAND_SHAPE_1]] to +// CHECK-SAME: offset: [0], sizes: [%[[LINEAR_SIZE_1]]], strides: [1] : +// CHECK-SAME: memref to memref +// CHECK: memref.copy %[[CAST]], %[[CAST_1]] : memref to memref +// CHECK: return +func.func @memref_copy_dynamic(%dim0 : index, %dim1: index, %dim2: index, %dim3 : index, %dim4: index, %dim5: index) { + %alloc = memref.alloca(%dim0, %dim1, %dim2) : memref + %alloc1 = memref.alloca(%dim3, %dim4, %dim5) : memref + memref.copy %alloc, %alloc1 : memref to memref + return +} From b4247c643db9beea71eca0263e42e15be78826d2 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Fri, 17 Jan 2025 08:36:55 +0000 Subject: [PATCH 06/10] Address review comment : Value -> OpFoldResult for basis Signed-off-by: Abhishek Varma --- .../Codegen/Common/LinearizeMemRefs.cpp | 49 ++++++++++--------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp index 15fbaf37adca..9034447c7646 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp @@ -102,23 +102,25 @@ static Value linearizeOperand(Location loc, PatternRewriter &rewriter, } } -static SmallVector getDimValues(Location loc, PatternRewriter &rewriter, - MemRefType type, - ValueRange dynamicDims) { - SmallVector dims; +static SmallVector getDimValues(Location loc, + PatternRewriter &rewriter, + MemRefType type, + ValueRange dynamicDims) { + SmallVector dims; auto shape = type.getShape(); int dynamicDimIndex = 0; for (int i = 0; i < shape.size(); ++i) { if (ShapedType::isDynamic(shape[i])) { dims.push_back(dynamicDims[dynamicDimIndex++]); } else { - dims.push_back(rewriter.create(loc, shape[i])); + dims.push_back( + rewriter.create(loc, shape[i]).getResult()); } } return dims; } -static FailureOr> +static FailureOr> getMixedOrigSize(Location loc, PatternRewriter &rewriter, Value sourceValue) { MemRefType sourceType = llvm::cast(sourceValue.getType()); Operation *sourceOp = sourceValue.getDefiningOp(); @@ -128,10 +130,11 @@ getMixedOrigSize(Location loc, PatternRewriter &rewriter, Value sourceValue) { return getDimValues(loc, rewriter, sourceType, allocaOp.getDynamicSizes()); } else { if (sourceType.hasStaticShape()) { - SmallVector dims; + SmallVector dims; dims.reserve(sourceType.getRank()); for (int64_t dim : sourceType.getShape()) { - dims.push_back(rewriter.create(loc, dim)); + dims.push_back( + rewriter.create(loc, dim).getResult()); } return dims; } else { @@ -163,11 +166,11 @@ struct LinearizeMemrefAlloc : public OpRewritePattern { SmallVector dynamicLinearizedSize; if (!newTypeOfSourceMemref.hasStaticShape()) { - SmallVector basis = getDimValues( + SmallVector basis = getDimValues( loc, rewriter, currentTypeOfSourceMemref, allocOp.getDynamicSizes()); SmallVector multiIndices( basis.size(), rewriter.create(loc, 0)); - multiIndices[0] = basis[0]; + multiIndices[0] = llvm::dyn_cast_if_present(basis[0]); Value linearizedSizes = rewriter.create( loc, multiIndices, basis, true); dynamicLinearizedSize.push_back(linearizedSizes); @@ -202,7 +205,7 @@ struct LinearizeMemrefLoad : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - FailureOr> basis = + FailureOr> basis = getMixedOrigSize(loc, rewriter, loadOp.getMemref()); if (failed(basis)) return failure(); @@ -211,7 +214,7 @@ struct LinearizeMemrefLoad : public OpRewritePattern { SmallVector multiIndices( (*basis).size(), rewriter.create(loc, 0)); - multiIndices[0] = (*basis)[0]; + multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); Value linearizedSizes = rewriter.create( loc, multiIndices, *basis, true); Value linearizedOperand = @@ -241,7 +244,7 @@ struct LinearizeMemrefStore : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - FailureOr> basis = + FailureOr> basis = getMixedOrigSize(loc, rewriter, storeOp.getMemref()); if (failed(basis)) return failure(); @@ -249,7 +252,7 @@ struct LinearizeMemrefStore : public OpRewritePattern { loc, storeOp.getIndices(), *basis, true); SmallVector multiIndices( (*basis).size(), rewriter.create(loc, 0)); - multiIndices[0] = (*basis)[0]; + multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); Value linearizedSizes = rewriter.create( loc, multiIndices, *basis, true); Value linearizedOperand = @@ -280,13 +283,13 @@ struct LinearizeMemrefDealloc : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - FailureOr> basis = + FailureOr> basis = getMixedOrigSize(loc, rewriter, deallocOp.getMemref()); if (failed(basis)) return failure(); SmallVector multiIndices( (*basis).size(), rewriter.create(loc, 0)); - multiIndices[0] = (*basis)[0]; + multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); Value linearizedSizes = rewriter.create( loc, multiIndices, *basis, true); Value linearizedOperand = @@ -318,13 +321,13 @@ struct LinearizeMemrefCopy : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - FailureOr> basis = + FailureOr> basis = getMixedOrigSize(loc, rewriter, copyOp.getSource()); if (failed(basis)) return failure(); SmallVector multiIndices( (*basis).size(), rewriter.create(loc, 0)); - multiIndices[0] = (*basis)[0]; + multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); Value linearizedSizes = rewriter.create( loc, multiIndices, *basis, true); Value linearizedSource = @@ -333,7 +336,7 @@ struct LinearizeMemrefCopy : public OpRewritePattern { basis = getMixedOrigSize(loc, rewriter, copyOp.getTarget()); if (failed(basis)) return failure(); - multiIndices[0] = (*basis)[0]; + multiIndices[0] = llvm::dyn_cast_if_present(((*basis)[0])); linearizedSizes = rewriter.create( loc, multiIndices, *basis, true); Value linearizedTarget = @@ -362,7 +365,7 @@ struct LinearizeVectorLoad : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - FailureOr> basis = + FailureOr> basis = getMixedOrigSize(loc, rewriter, loadOp.getBase()); if (failed(basis)) return failure(); @@ -370,7 +373,7 @@ struct LinearizeVectorLoad : public OpRewritePattern { loc, loadOp.getIndices(), *basis, true); SmallVector multiIndices( (*basis).size(), rewriter.create(loc, 0)); - multiIndices[0] = (*basis)[0]; + multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); Value linearizedSizes = rewriter.create( loc, multiIndices, *basis, true); Value linearizedOperand = @@ -401,7 +404,7 @@ struct LinearizeVectorStore : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - FailureOr> basis = + FailureOr> basis = getMixedOrigSize(loc, rewriter, storeOp.getBase()); if (failed(basis)) return failure(); @@ -409,7 +412,7 @@ struct LinearizeVectorStore : public OpRewritePattern { loc, storeOp.getIndices(), *basis, true); SmallVector multiIndices( (*basis).size(), rewriter.create(loc, 0)); - multiIndices[0] = (*basis)[0]; + multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); Value linearizedSizes = rewriter.create( loc, multiIndices, *basis, true); Value linearizedOperand = From e092b02447933d88bbd8150c7c9d67915fca4d66 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Fri, 17 Jan 2025 17:24:18 +0000 Subject: [PATCH 07/10] Use applyPatternsGreedily Signed-off-by: Abhishek Varma --- compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp index 9034447c7646..a16f30b38108 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp @@ -450,7 +450,7 @@ void LinearizeMemRefs::runOnOperation() { patterns.add(context); patterns.add(context); - (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); + (void)applyPatternsGreedily(moduleOp, std::move(patterns)); return; } From 8f3edf15731f20b3f172f9d6e05e006db52bef56 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Mon, 3 Feb 2025 09:54:36 +0000 Subject: [PATCH 08/10] Review comments by Quinn v2.0 Signed-off-by: Abhishek Varma --- .../Codegen/Common/LinearizeMemRefs.cpp | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp index a16f30b38108..ebfbdba6a946 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp @@ -1,4 +1,4 @@ -// Copyright 2024 The IREE Authors +// Copyright 2025 The IREE Authors // // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -38,7 +38,7 @@ static SmallVector getLinearizedShape(MemRefType type) { return {}; int64_t linearizedShape = 1; - for (auto shape : type.getShape()) { + for (int64_t shape : type.getShape()) { if (shape == ShapedType::kDynamic) return {ShapedType::kDynamic}; linearizedShape *= shape; @@ -54,7 +54,7 @@ static FailureOr linearizeType(MemRefType memrefType) { int64_t offset; if (failed(getStridesAndOffset(memrefType, strides, offset))) return failure(); - if (!strides.empty() && strides.back() != 1) + if (strides.empty()) return failure(); // Form layout for the linearized memref. StridedLayoutAttr layoutAttr; @@ -83,9 +83,10 @@ static Value linearizeOperand(Location loc, PatternRewriter &rewriter, // Fetch offset and strides of the old memref. SmallVector strides; int64_t offset; - if (failed(getStridesAndOffset(linearizedType, strides, offset))) + if (failed(getStridesAndOffset(linearizedType, strides, offset))) { // TODO(avarma): Change function signature. return nullptr; + } if (linearizedType.hasStaticShape()) { return rewriter.create( @@ -94,11 +95,11 @@ static Value linearizeOperand(Location loc, PatternRewriter &rewriter, } else { assert(sizeRes && "expected a linear dynamic size to have been computed"); SmallVector size = {sizeRes}; - OpFoldResult zeroOffset = rewriter.getIndexAttr(offset); + OpFoldResult staticOffset = rewriter.getIndexAttr(offset); SmallVector strides; strides.push_back(rewriter.getIndexAttr(1)); return rewriter.create( - loc, linearizedType, operand, zeroOffset, size, strides); + loc, linearizedType, operand, staticOffset, size, strides); } } @@ -179,7 +180,7 @@ struct LinearizeMemrefAlloc : public OpRewritePattern { Value linearizedOp = rewriter.create( loc, newTypeOfSourceMemref, dynamicLinearizedSize, allocOp.getSymbolOperands(), allocOp.getAlignmentAttr()); - SmallVector indices(currentTypeOfSourceMemref.getRank()); + SmallVector indices(currentTypeOfSourceMemref.getRank(), 0); std::iota(indices.begin(), indices.end(), 0); Value delinearizedOp = rewriter.create( loc, currentTypeOfSourceMemref, linearizedOp, ArrayRef({indices}), @@ -220,10 +221,9 @@ struct LinearizeMemrefLoad : public OpRewritePattern { Value linearizedOperand = linearizeOperand(loc, rewriter, loadOp.getMemref(), newTypeOfSourceMemref, linearizedSizes); - Value linearizedLoad = rewriter.create( - loc, linearizedOperand, linearizedIndices, loadOp.getNontemporal()); + rewriter.replaceOpWithNewOp( + loadOp, linearizedOperand, linearizedIndices, loadOp.getNontemporal()); - rewriter.replaceOp(loadOp, {linearizedLoad}); return success(); } }; @@ -336,7 +336,7 @@ struct LinearizeMemrefCopy : public OpRewritePattern { basis = getMixedOrigSize(loc, rewriter, copyOp.getTarget()); if (failed(basis)) return failure(); - multiIndices[0] = llvm::dyn_cast_if_present(((*basis)[0])); + multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); linearizedSizes = rewriter.create( loc, multiIndices, *basis, true); Value linearizedTarget = @@ -380,10 +380,9 @@ struct LinearizeVectorLoad : public OpRewritePattern { linearizeOperand(loc, rewriter, loadOp.getBase(), newTypeOfSourceMemref, linearizedSizes); - Value linearizedLoad = rewriter.create( - loc, loadOp.getType(), linearizedOperand, linearizedIndices); + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), linearizedOperand, linearizedIndices); - rewriter.replaceOp(loadOp, {linearizedLoad}); return success(); } }; From 1f2683a768dca3e8e51fa9e575effef9347b013b Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Mon, 3 Feb 2025 10:13:03 +0000 Subject: [PATCH 09/10] Update getStridesAndOffset Signed-off-by: Abhishek Varma --- .../src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp index ebfbdba6a946..ce9934ea6e01 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp @@ -52,7 +52,7 @@ static FailureOr linearizeType(MemRefType memrefType) { // Fetch offset and strides of the old memref. SmallVector strides; int64_t offset; - if (failed(getStridesAndOffset(memrefType, strides, offset))) + if (failed(memrefType.getStridesAndOffset(strides, offset))) return failure(); if (strides.empty()) return failure(); @@ -83,7 +83,7 @@ static Value linearizeOperand(Location loc, PatternRewriter &rewriter, // Fetch offset and strides of the old memref. SmallVector strides; int64_t offset; - if (failed(getStridesAndOffset(linearizedType, strides, offset))) { + if (failed(linearizedType.getStridesAndOffset(strides, offset))) { // TODO(avarma): Change function signature. return nullptr; } From 67a34399aafdcab8c758493e8be5bb88a9878e28 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Fri, 7 Feb 2025 08:16:50 +0000 Subject: [PATCH 10/10] Refactor + use getMixedSizes Signed-off-by: Abhishek Varma --- .../Codegen/Common/LinearizeMemRefs.cpp | 193 +++++++----------- 1 file changed, 76 insertions(+), 117 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp index ce9934ea6e01..6351fab36b24 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinearizeMemRefs.cpp @@ -103,45 +103,32 @@ static Value linearizeOperand(Location loc, PatternRewriter &rewriter, } } -static SmallVector getDimValues(Location loc, - PatternRewriter &rewriter, - MemRefType type, - ValueRange dynamicDims) { - SmallVector dims; - auto shape = type.getShape(); - int dynamicDimIndex = 0; - for (int i = 0; i < shape.size(); ++i) { - if (ShapedType::isDynamic(shape[i])) { - dims.push_back(dynamicDims[dynamicDimIndex++]); - } else { - dims.push_back( - rewriter.create(loc, shape[i]).getResult()); - } - } - return dims; -} - -static FailureOr> -getMixedOrigSize(Location loc, PatternRewriter &rewriter, Value sourceValue) { - MemRefType sourceType = llvm::cast(sourceValue.getType()); - Operation *sourceOp = sourceValue.getDefiningOp(); - if (auto allocOp = dyn_cast_if_present(sourceOp)) { - return getDimValues(loc, rewriter, sourceType, allocOp.getDynamicSizes()); - } else if (auto allocaOp = dyn_cast_if_present(sourceOp)) { - return getDimValues(loc, rewriter, sourceType, allocaOp.getDynamicSizes()); - } else { - if (sourceType.hasStaticShape()) { - SmallVector dims; - dims.reserve(sourceType.getRank()); - for (int64_t dim : sourceType.getShape()) { - dims.push_back( - rewriter.create(loc, dim).getResult()); - } - return dims; +/// Utility function used to linearize indices/sizes. +static FailureOr +createLinearizeIndexOp(PatternRewriter &rewriter, Location loc, Value memref, + SmallVector multiIndices = {}) { + // Step 1. Form basis. + FailureOr> basis = + memref::getMixedSizes(rewriter, loc, memref); + if (failed(basis)) + return failure(); + // Step 2. Form multiIndices if not already present. + if (multiIndices.empty()) { + SmallVector multiIndicesNew( + (*basis).size(), rewriter.create(loc, 0)); + if (auto attr = llvm::dyn_cast_if_present((*basis)[0])) { + auto intVal = cast(attr).getInt(); + multiIndicesNew[0] = + rewriter.create(loc, intVal).getResult(); } else { - return failure(); + multiIndicesNew[0] = llvm::dyn_cast_if_present((*basis)[0]); } + multiIndices = multiIndicesNew; } + // Step 3. Form linearizedIndexOp. + Value linearizedSizes = rewriter.create( + loc, multiIndices, *basis, true); + return linearizedSizes; } template @@ -167,14 +154,11 @@ struct LinearizeMemrefAlloc : public OpRewritePattern { SmallVector dynamicLinearizedSize; if (!newTypeOfSourceMemref.hasStaticShape()) { - SmallVector basis = getDimValues( - loc, rewriter, currentTypeOfSourceMemref, allocOp.getDynamicSizes()); - SmallVector multiIndices( - basis.size(), rewriter.create(loc, 0)); - multiIndices[0] = llvm::dyn_cast_if_present(basis[0]); - Value linearizedSizes = rewriter.create( - loc, multiIndices, basis, true); - dynamicLinearizedSize.push_back(linearizedSizes); + FailureOr linearizedSizes = + createLinearizeIndexOp(rewriter, loc, allocOp.getResult()); + if (failed(linearizedSizes)) + return failure(); + dynamicLinearizedSize.push_back(*linearizedSizes); } Value linearizedOp = rewriter.create( @@ -206,23 +190,19 @@ struct LinearizeMemrefLoad : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - FailureOr> basis = - getMixedOrigSize(loc, rewriter, loadOp.getMemref()); - if (failed(basis)) + FailureOr linearizedIndices = createLinearizeIndexOp( + rewriter, loc, loadOp.getMemref(), loadOp.getIndices()); + if (failed(linearizedIndices)) + return failure(); + FailureOr linearizedSizes = + createLinearizeIndexOp(rewriter, loc, loadOp.getMemref()); + if (failed(linearizedSizes)) return failure(); - Value linearizedIndices = rewriter.create( - loc, loadOp.getIndices(), *basis, true); - - SmallVector multiIndices( - (*basis).size(), rewriter.create(loc, 0)); - multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); - Value linearizedSizes = rewriter.create( - loc, multiIndices, *basis, true); Value linearizedOperand = linearizeOperand(loc, rewriter, loadOp.getMemref(), - newTypeOfSourceMemref, linearizedSizes); + newTypeOfSourceMemref, *linearizedSizes); rewriter.replaceOpWithNewOp( - loadOp, linearizedOperand, linearizedIndices, loadOp.getNontemporal()); + loadOp, linearizedOperand, *linearizedIndices, loadOp.getNontemporal()); return success(); } @@ -244,24 +224,21 @@ struct LinearizeMemrefStore : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - FailureOr> basis = - getMixedOrigSize(loc, rewriter, storeOp.getMemref()); - if (failed(basis)) + FailureOr linearizedIndices = createLinearizeIndexOp( + rewriter, loc, storeOp.getMemref(), storeOp.getIndices()); + if (failed(linearizedIndices)) + return failure(); + FailureOr linearizedSizes = + createLinearizeIndexOp(rewriter, loc, storeOp.getMemref()); + if (failed(linearizedSizes)) return failure(); - Value linearizedIndices = rewriter.create( - loc, storeOp.getIndices(), *basis, true); - SmallVector multiIndices( - (*basis).size(), rewriter.create(loc, 0)); - multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); - Value linearizedSizes = rewriter.create( - loc, multiIndices, *basis, true); Value linearizedOperand = linearizeOperand(loc, rewriter, storeOp.getMemref(), - newTypeOfSourceMemref, linearizedSizes); + newTypeOfSourceMemref, *linearizedSizes); rewriter.replaceOpWithNewOp( storeOp, storeOp.getValueToStore(), linearizedOperand, - linearizedIndices, storeOp.getNontemporal()); + *linearizedIndices, storeOp.getNontemporal()); return success(); } @@ -283,18 +260,13 @@ struct LinearizeMemrefDealloc : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - FailureOr> basis = - getMixedOrigSize(loc, rewriter, deallocOp.getMemref()); - if (failed(basis)) + FailureOr linearizedSizes = + createLinearizeIndexOp(rewriter, loc, deallocOp.getMemref()); + if (failed(linearizedSizes)) return failure(); - SmallVector multiIndices( - (*basis).size(), rewriter.create(loc, 0)); - multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); - Value linearizedSizes = rewriter.create( - loc, multiIndices, *basis, true); Value linearizedOperand = linearizeOperand(loc, rewriter, deallocOp.getMemref(), - newTypeOfSourceMemref, linearizedSizes); + newTypeOfSourceMemref, *linearizedSizes); rewriter.replaceOpWithNewOp(deallocOp, linearizedOperand); @@ -321,27 +293,19 @@ struct LinearizeMemrefCopy : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - FailureOr> basis = - getMixedOrigSize(loc, rewriter, copyOp.getSource()); - if (failed(basis)) + FailureOr linearizedSizes = + createLinearizeIndexOp(rewriter, loc, copyOp.getSource()); + if (failed(linearizedSizes)) return failure(); - SmallVector multiIndices( - (*basis).size(), rewriter.create(loc, 0)); - multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); - Value linearizedSizes = rewriter.create( - loc, multiIndices, *basis, true); Value linearizedSource = linearizeOperand(loc, rewriter, copyOp.getSource(), - newTypeOfSourceMemref, linearizedSizes); - basis = getMixedOrigSize(loc, rewriter, copyOp.getTarget()); - if (failed(basis)) + newTypeOfSourceMemref, *linearizedSizes); + linearizedSizes = createLinearizeIndexOp(rewriter, loc, copyOp.getTarget()); + if (failed(linearizedSizes)) return failure(); - multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); - linearizedSizes = rewriter.create( - loc, multiIndices, *basis, true); Value linearizedTarget = linearizeOperand(loc, rewriter, copyOp.getTarget(), - newTypeOfSourceMemref, linearizedSizes); + newTypeOfSourceMemref, *linearizedSizes); rewriter.replaceOpWithNewOp(copyOp, linearizedSource, linearizedTarget); @@ -365,23 +329,20 @@ struct LinearizeVectorLoad : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - FailureOr> basis = - getMixedOrigSize(loc, rewriter, loadOp.getBase()); - if (failed(basis)) + FailureOr linearizedIndices = createLinearizeIndexOp( + rewriter, loc, loadOp.getBase(), loadOp.getIndices()); + if (failed(linearizedIndices)) + return failure(); + FailureOr linearizedSizes = + createLinearizeIndexOp(rewriter, loc, loadOp.getBase()); + if (failed(linearizedSizes)) return failure(); - Value linearizedIndices = rewriter.create( - loc, loadOp.getIndices(), *basis, true); - SmallVector multiIndices( - (*basis).size(), rewriter.create(loc, 0)); - multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); - Value linearizedSizes = rewriter.create( - loc, multiIndices, *basis, true); Value linearizedOperand = linearizeOperand(loc, rewriter, loadOp.getBase(), newTypeOfSourceMemref, - linearizedSizes); + *linearizedSizes); rewriter.replaceOpWithNewOp( - loadOp, loadOp.getType(), linearizedOperand, linearizedIndices); + loadOp, loadOp.getType(), linearizedOperand, *linearizedIndices); return success(); } @@ -403,28 +364,26 @@ struct LinearizeVectorStore : public OpRewritePattern { return failure(); MemRefType newTypeOfSourceMemref = *maybeNewTypeOfSourceMemref; - FailureOr> basis = - getMixedOrigSize(loc, rewriter, storeOp.getBase()); - if (failed(basis)) + FailureOr linearizedIndices = createLinearizeIndexOp( + rewriter, loc, storeOp.getBase(), storeOp.getIndices()); + if (failed(linearizedIndices)) + return failure(); + FailureOr linearizedSizes = + createLinearizeIndexOp(rewriter, loc, storeOp.getBase()); + if (failed(linearizedSizes)) return failure(); - Value linearizedIndices = rewriter.create( - loc, storeOp.getIndices(), *basis, true); - SmallVector multiIndices( - (*basis).size(), rewriter.create(loc, 0)); - multiIndices[0] = llvm::dyn_cast_if_present((*basis)[0]); - Value linearizedSizes = rewriter.create( - loc, multiIndices, *basis, true); Value linearizedOperand = linearizeOperand(loc, rewriter, storeOp.getBase(), - newTypeOfSourceMemref, linearizedSizes); + newTypeOfSourceMemref, *linearizedSizes); rewriter.replaceOpWithNewOp( storeOp, storeOp.getValueToStore(), linearizedOperand, - linearizedIndices); + *linearizedIndices); return success(); } }; + //===----------------------------------------------------------------------===// // Pass //===----------------------------------------------------------------------===//