Bubble expand shapes through AttentionOps (iree-org#18074)
The idea was to land these changes incrementally (see
iree-org#18030 and
iree-org#18017), but it turns out they
depend on each other to avoid regressions in SDXL fp16. Below is a
summary of what the commits here do.

## 1. Make all input maps identity (originally on PR iree-org#18017)
First two commits:
> Change ElementwiseOpInterchangePattern from making the output maps identity to making the input maps identity. iree-org#18006
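
For illustration, here is a hypothetical example (not from this PR's tests; the op, shapes, and names are made up) of an elementwise op whose single input is read through a transposed map while the output map is the identity. The updated pattern interchanges the loops so that the input map, rather than the output map, becomes the identity:

```mlir
#transposed = affine_map<(d0, d1) -> (d1, d0)>
#identity   = affine_map<(d0, d1) -> (d0, d1)>
func.func @interchange_example(%arg0: tensor<4x8xf16>) -> tensor<8x4xf16> {
  %init = tensor.empty() : tensor<8x4xf16>
  // Input read through #transposed, output written through #identity.
  %0 = linalg.generic {indexing_maps = [#transposed, #identity],
                       iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<4x8xf16>) outs(%init : tensor<8x4xf16>) {
  ^bb0(%in: f16, %out: f16):
    %1 = arith.negf %in : f16
    linalg.yield %1 : f16
  } -> tensor<8x4xf16>
  return %0 : tensor<8x4xf16>
}
// After ElementwiseOpInterchangePattern applies the interchange [1, 0],
// the indexing maps become [#identity, #transposed]: the input is read
// with an identity map and the transpose moves to the output map.
```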

## 2. Add `LinalgFusionOpInterface` to attention op (originally on PR iree-org#18030)
Next two commits after the merge:
> - Added the interface to the attention op so that producers/consumers can get fused into the same dispatch.
> - Cleaned up interface method naming by changing `getIndexingMaps` to `getIndexingMapsArray` to better match linalg; `getIndexingMaps` also conflicted with a method that attention already has.
> - Added an indexing map check to the attention verifier and a corresponding test.
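
As a sketch of what the interface enables, consider an attention op (shapes and indexing maps mirror the new lit test below) followed by a hypothetical elementwise consumer. With `LinalgFusionOpInterface` attached, dispatch region formation can query the attention op's indexing maps (via `getIndexingMapsArray`) and consider pulling such a consumer into the same dispatch:

```mlir
#mapQ  = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#mapK  = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
#mapV  = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
#mapO  = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
#mapId = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

util.func public @attention_with_consumer(%q: tensor<20x4096x16xf16>,
    %k: tensor<20x1024x16xf16>, %v: tensor<20x1024x64xf16>, %scale: f16)
    -> tensor<20x4096x64xf16> {
  %empty = tensor.empty() : tensor<20x4096x64xf16>
  %attn = iree_linalg_ext.attention
      {indexing_maps = [#mapQ, #mapK, #mapV, #mapO]}
      ins(%q, %k, %v, %scale : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>,
          tensor<20x1024x64xf16>, f16)
      outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
  // Hypothetical elementwise consumer of the attention result; a candidate
  // for fusion into the same dispatch as the attention op.
  %out = tensor.empty() : tensor<20x4096x64xf16>
  %scaled = linalg.generic {indexing_maps = [#mapId, #mapId],
                            iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%attn : tensor<20x4096x64xf16>) outs(%out : tensor<20x4096x64xf16>) {
  ^bb0(%in: f16, %o: f16):
    %m = arith.mulf %in, %scale : f16
    linalg.yield %m : f16
  } -> tensor<20x4096x64xf16>
  util.return %scaled : tensor<20x4096x64xf16>
}
```

Note that, per the `isFusableWithProducer` change below, attention is still not fused with its producers; the interface is what lets the fusion machinery reason about the op in the first place.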

## 3. Most recent commit on this PR
This was going to be the only commit of this PR and is the one that
needs the most review. It adds
`LinalgExt::populateFoldReshapeOpsByExpansionPatterns`, which emulates
`linalg::populateFoldReshapeOpsByExpansionPatterns` but handles AttentionOp.
More background is in iree-org#17673.
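
The first case of the `attention_fuse_by_expansion.mlir` test added below shows the effect; condensed:

```mlir
#map  = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
util.func public @attention_static(%q: tensor<20x4096x16xf16>,
    %k: tensor<20x1024x16xf16>, %v: tensor<20x1024x64xf16>, %scale: f16)
    -> tensor<2x10x4096x64xf16> {
  %0 = tensor.empty() : tensor<20x4096x64xf16>
  // The expand_shape consumes the attention result ...
  %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]}
      ins(%q, %k, %v, %scale : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>,
          tensor<20x1024x64xf16>, f16)
      outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
  %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]]
      output_shape [2, 10, 4096, 64]
      : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
  util.return %expanded : tensor<2x10x4096x64xf16>
}
// ... and after iree-flow-bubble-up-expand-shapes the expand_shapes are
// applied to %q, %k, and %v instead, while the attention op itself runs in
// the expanded rank-4 space (batch 20 split into 2x10) with indexing maps
// over (d0, ..., d5).
```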

---------

Signed-off-by: Ian Wood <[email protected]>
Signed-off-by: saienduri <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Co-authored-by: MaheshRavishankar <[email protected]>
Co-authored-by: saienduri <[email protected]>
3 people authored Aug 21, 2024
1 parent 6a92fb7 commit 8dd1db3
Showing 21 changed files with 778 additions and 48 deletions.
20 changes: 10 additions & 10 deletions .github/workflows/pkgci_regression_test.yml
@@ -332,13 +332,13 @@ jobs:
run: |
source ${VENV_DIR}/bin/activate
pytest ./experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py \
--goldentime-rocm-e2e-ms 1450.0 \
--goldentime-rocm-unet-ms 370.0 \
--goldentime-rocm-e2e-ms 1616.0 \
--goldentime-rocm-unet-ms 419.0 \
--goldentime-rocm-clip-ms 18.5 \
--goldentime-rocm-vae-ms 315.0 \
--goldendispatch-rocm-unet 1691 \
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1551 \
--goldendispatch-rocm-clip 1225 \
--goldendispatch-rocm-vae 248 \
--goldendispatch-rocm-vae 247 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
@@ -354,13 +354,13 @@ jobs:
run: |
source ${VENV_DIR}/bin/activate
pytest ./experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py \
--goldentime-rocm-e2e-ms 325.0 \
--goldentime-rocm-unet-ms 77.0 \
--goldentime-rocm-e2e-ms 372.0 \
--goldentime-rocm-unet-ms 95.0 \
--goldentime-rocm-clip-ms 15.5 \
--goldentime-rocm-vae-ms 74.0 \
--goldendispatch-rocm-unet 1691 \
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1551 \
--goldendispatch-rocm-clip 1225 \
--goldendispatch-rocm-vae 248 \
--goldendispatch-rocm-vae 247 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
@@ -15,10 +15,12 @@
#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

@@ -82,6 +84,9 @@ void BubbleUpExpandShapesPass::runOnOperation() {
};
linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
bubbleUpExpansionControlFn);
LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
bubbleExpandShapePatterns, bubbleUpExpansionControlFn);

// Add patterns to do some additional cleanup (on top of canonicalizations
// that can be done later) of reshape ops.
tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
@@ -744,6 +744,11 @@ isFusableWithProducer(OpOperand &operand,
return true;
}

// Don't fuse attention with its producer.
if (isa<LinalgExt::AttentionOp>(consumer)) {
return false;
}

if (isPackLikeOp(consumer)) {
return TypeSwitch<Operation *, bool>(producer)
.Case<tensor::PadOp>([&](auto padOp) { return true; })
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -41,23 +42,36 @@ namespace {
// ElementwiseOpInterchangePattern
//===----------------------------------------------------------------------===//

// If possible, interchange indexing maps to make input maps all identity.
struct ElementwiseOpInterchangePattern
: public OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1)
if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1 ||
genericOp.getNumDpsInputs() == 0)
return failure();

AffineMap indexingMap = genericOp.getIndexingMapsArray().back();
if (indexingMap.isIdentity())
// All input maps must be equal and non-identity. All maps, including
// output, must be permutations. Permutation maps are checked by
// isElementwise but may be removed.
AffineMap inputMap = genericOp.getIndexingMapsArray().front();
auto *initOperand = genericOp.getDpsInitOperand(0);
if (inputMap.isIdentity() || !inputMap.isPermutation() ||
!genericOp.getMatchingIndexingMap(initOperand).isPermutation()) {
return failure();
}
for (auto *operand : genericOp.getDpsInputOperands()) {
if (genericOp.getMatchingIndexingMap(operand) != inputMap) {
return failure();
}
}

ArrayRef<AffineExpr> exprs = indexingMap.getResults();
// Make all inputs identity.
ArrayRef<AffineExpr> exprs = inputMap.getResults();
auto perm = llvm::map_to_vector(exprs, [](AffineExpr e) -> unsigned {
return cast<AffineDimExpr>(e).getPosition();
});

return linalg::interchangeGenericOp(rewriter, genericOp, perm);
}
};
@@ -210,6 +224,7 @@ struct FusionPreprocessingPass
// operand shapes.
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
@@ -17,6 +17,7 @@ iree_lit_test_suite(
srcs = enforce_glob(
[
"annotate_dispatches.mlir",
"attention_fuse_by_expansion.mlir",
"capture_dispatch_dynamic_dims.mlir",
"capture_scf_for_dynamic_dims.mlir",
"cleanup_tensor_shapes.mlir",
@@ -15,6 +15,7 @@ iree_lit_test_suite(
lit
SRCS
"annotate_dispatches.mlir"
"attention_fuse_by_expansion.mlir"
"capture_dispatch_dynamic_dims.mlir"
"capture_scf_for_dynamic_dims.mlir"
"cleanup_tensor_shapes.mlir"
@@ -0,0 +1,127 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-flow-bubble-up-expand-shapes, canonicalize, cse, canonicalize))" %s | FileCheck %s

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>

util.func public @attention_static(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x4096x64xf16> {
%0 = tensor.empty() : tensor<20x4096x64xf16>
%1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
%expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
util.return %expanded : tensor<2x10x4096x64xf16>
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
// CHECK: func public @attention_static(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
// CHECK-SAME: %[[ARG3:.+]]: f16)
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x4096x64xf16>
// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16]
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16]
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64]
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: util.return %[[ATTENTION]]

// -----

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>

util.func public @attention_expand_all(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x2048x2x2x32xf16> {
%0 = tensor.empty() : tensor<20x4096x64xf16>
%1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
%expanded = tensor.expand_shape %1 [[0, 1], [2, 3], [4, 5]] output_shape [2, 10, 2048, 2, 2, 32] : tensor<20x4096x64xf16> into tensor<2x10x2048x2x2x32xf16>
util.return %expanded : tensor<2x10x2048x2x2x32xf16>
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d7)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d6, d7)>
// CHECK: func public @attention_expand_all(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
// CHECK-SAME: %[[ARG3:.+]]: f16)
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x2048x2x2x32xf16>
// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3], [4]] output_shape [2, 10, 2048, 2, 16] : tensor<20x4096x16xf16> into tensor<2x10x2048x2x16xf16>
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3]] output_shape [2, 10, 1024, 16] : tensor<20x1024x16xf16> into tensor<2x10x1024x16xf16>
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [2, 10, 1024, 2, 32] : tensor<20x1024x64xf16> into tensor<2x10x1024x2x32xf16>
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: util.return %[[ATTENTION]]

// -----


#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>

util.func public @attention_dynamic(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16) -> tensor<2x?x?x?xf16> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf16>
%d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf16>
%d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf16>
%d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16>
%d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16>
%0 = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16>
%1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%0 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
%split = arith.divsi %d0, %c2 : index
%expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape[2, %split, %d1, %d4]
: tensor<?x?x?xf16> into tensor<2x?x?x?xf16>
util.return %expanded : tensor<2x?x?x?xf16>
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
// CHECK: func public @attention_dynamic(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
// CHECK-SAME: %[[ARG3:.+]]: f16)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
// CHECK-DAG: %[[SPLIT0:.+]] = arith.divui %[[D0]]
// CHECK-DAG: %[[VAL:.+]] = affine.apply #[[MAP4]]()[%[[D0]]]
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[VAL]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16>
// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]]
// CHECK-DAG: %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]]
// CHECK-DAG: %[[SPLIT1:.+]] = arith.divui %[[D5]], %[[C2]]
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]]
// CHECK-DAG: %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]]
// CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]]
// CHECK-DAG: %[[SPLIT2:.+]] = arith.divui %[[D8]], %[[C2]]
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]]
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: util.return %[[ATTENTION]]