From 8dd1db34174c603e6a471dd7523ead112f721245 Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Wed, 21 Aug 2024 09:24:13 -0700 Subject: [PATCH] Bubble expand shapes through `AttentionOp`s (#18074) The idea was to incrementally land these changes (see https://github.com/iree-org/iree/pull/18030 and https://github.com/iree-org/iree/pull/18017), but they appear to depend on each other to avoid regressions in SDXL fp16. I will give a summary of what the commits here do. ## 1. Make all input maps identity (originally on PR https://github.com/iree-org/iree/pull/18017) First two commits: > Change ElementwiseOpInterchangePattern from making output maps identity to making the input maps identity. https://github.com/iree-org/iree/issues/18006 ## 2. Add `LinalgFusionOpInterface` to attention op (originally on PR https://github.com/iree-org/iree/pull/18030) Next two commits after the merge: >- Added the interface to the attention op so that producers/consumers can get fused into the same dispatch. >- Cleaned up interface method naming by changing `getIndexingMaps` to `getIndexingMapsArray` to better match linalg; `getIndexingMaps` also conflicted with a method that attention already has. >- Added an indexing-map check to the attention verifier and a corresponding test. ## 3. Most recent commit on this PR This was going to be the only commit of this PR and is the one that needs the most review. It adds LinalgExt::populateFoldReshapeOpsByExpansionPatterns, which emulates linalg::populateFoldReshapeOpsByExpansionPatterns but for AttentionOp. A bit more info here: https://github.com/iree-org/iree/issues/17673. --------- Signed-off-by: Ian Wood Signed-off-by: saienduri Signed-off-by: MaheshRavishankar Co-authored-by: MaheshRavishankar Co-authored-by: saienduri --- .github/workflows/pkgci_regression_test.yml | 20 +- .../Flow/Transforms/BubbleUpExpandShapes.cpp | 5 + .../Flow/Transforms/FormDispatchRegions.cpp | 5 + .../Flow/Transforms/FusionPreprocessing.cpp | 25 +- .../Dialect/Flow/Transforms/test/BUILD.bazel | 1 + .../Flow/Transforms/test/CMakeLists.txt | 1 + .../test/attention_fuse_by_expansion.mlir | 127 +++++++ .../test/dispatch_linalg_ext_fusion.mlir | 83 ++++- .../Transforms/test/fusion_preprocessing.mlir | 70 +++- .../Dialect/LinalgExt/IR/LinalgExtDialect.cpp | 5 +- .../LinalgExt/IR/LinalgExtInterfaces.td | 8 +- .../Dialect/LinalgExt/IR/LinalgExtOps.cpp | 15 + .../Dialect/LinalgExt/IR/LinalgExtOps.td | 5 +- .../Dialect/LinalgExt/IR/test/invalid.mlir | 13 + .../Dialect/LinalgExt/Transforms/BUILD.bazel | 2 + .../LinalgExt/Transforms/CMakeLists.txt | 2 + .../LinalgExt/Transforms/ReshapeFusion.cpp | 347 ++++++++++++++++++ .../Transforms/TilingInterfaceImpl.cpp | 23 ++ .../Dialect/LinalgExt/Transforms/Transforms.h | 16 + .../LinalgExt/Transforms/test/tiling.mlir | 51 +++ .../shark-test-suite-models/sdxl/test_unet.py | 2 +- 21 files changed, 778 insertions(+), 48 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/Flow/Transforms/test/attention_fuse_by_expansion.mlir create mode 100644 compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp create mode 100644 compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index 48ca9ecdbfe4..7c47e313468a 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -332,13 +332,13 @@ jobs: run: | source ${VENV_DIR}/bin/activate 
pytest ./experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py \ - --goldentime-rocm-e2e-ms 1450.0 \ - --goldentime-rocm-unet-ms 370.0 \ + --goldentime-rocm-e2e-ms 1616.0 \ + --goldentime-rocm-unet-ms 419.0 \ --goldentime-rocm-clip-ms 18.5 \ - --goldentime-rocm-vae-ms 315.0 \ - --goldendispatch-rocm-unet 1691 \ + --goldentime-rocm-vae-ms 337.0 \ + --goldendispatch-rocm-unet 1551 \ --goldendispatch-rocm-clip 1225 \ - --goldendispatch-rocm-vae 248 \ + --goldendispatch-rocm-vae 247 \ --goldensize-rocm-unet-bytes 2280000 \ --goldensize-rocm-clip-bytes 860000 \ --goldensize-rocm-vae-bytes 840000 \ @@ -354,13 +354,13 @@ jobs: run: | source ${VENV_DIR}/bin/activate pytest ./experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py \ - --goldentime-rocm-e2e-ms 325.0 \ - --goldentime-rocm-unet-ms 77.0 \ + --goldentime-rocm-e2e-ms 372.0 \ + --goldentime-rocm-unet-ms 95.0 \ --goldentime-rocm-clip-ms 15.5 \ - --goldentime-rocm-vae-ms 74.0 \ - --goldendispatch-rocm-unet 1691 \ + --goldentime-rocm-vae-ms 80.0 \ + --goldendispatch-rocm-unet 1551 \ --goldendispatch-rocm-clip 1225 \ - --goldendispatch-rocm-vae 248 \ + --goldendispatch-rocm-vae 247 \ --goldensize-rocm-unet-bytes 2270000 \ --goldensize-rocm-clip-bytes 860000 \ --goldensize-rocm-vae-bytes 840000 \ diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp index 0e755870aa09..422ace837b9d 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BubbleUpExpandShapes.cpp @@ -15,10 +15,12 @@ #include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -82,6 +84,9 @@ void BubbleUpExpandShapesPass::runOnOperation() { }; linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns, bubbleUpExpansionControlFn); + LinalgExt::populateFoldReshapeOpsByExpansionPatterns( + bubbleExpandShapePatterns, bubbleUpExpansionControlFn); + // Add patterns to do some additional cleanup (on top of canonicalizations // that can be done later) of reshape ops. 
tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp index e39096b9847f..553a097c7875 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp @@ -744,6 +744,11 @@ isFusableWithProducer(OpOperand &operand, return true; } + // Don't fuse attention with its producer. + if (isa<IREE::LinalgExt::AttentionOp>(consumer)) { + return false; + } + if (isPackLikeOp(consumer)) { return TypeSwitch<Operation *, bool>(producer) .Case<tensor::PadOp>([&](auto padOp) { return true; }) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp index f95614d0149e..28932b3426af 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp @@ -24,6 +24,7 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" @@ -41,23 +42,36 @@ namespace { // ElementwiseOpInterchangePattern //===----------------------------------------------------------------------===// +// If possible, interchange indexing maps to make input maps all identity. struct ElementwiseOpInterchangePattern : public OpRewritePattern<linalg::GenericOp> { using OpRewritePattern<linalg::GenericOp>::OpRewritePattern; LogicalResult matchAndRewrite(linalg::GenericOp genericOp, PatternRewriter &rewriter) const override { - if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1) + if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1 || + genericOp.getNumDpsInputs() == 0) return failure(); - AffineMap indexingMap = genericOp.getIndexingMapsArray().back(); - if (indexingMap.isIdentity()) + // All input maps must be equal and non-identity. All maps, including + // output, must be permutations. Permutation maps are checked by + // isElementwise, but that check may be removed. AffineMap inputMap = genericOp.getIndexingMapsArray().front(); + auto *initOperand = genericOp.getDpsInitOperand(0); + if (inputMap.isIdentity() || !inputMap.isPermutation() || + !genericOp.getMatchingIndexingMap(initOperand).isPermutation()) { return failure(); + } + for (auto *operand : genericOp.getDpsInputOperands()) { + if (genericOp.getMatchingIndexingMap(operand) != inputMap) { + return failure(); + } + } - ArrayRef<AffineExpr> exprs = indexingMap.getResults(); + // Make all inputs identity. ArrayRef<AffineExpr> exprs = inputMap.getResults(); auto perm = llvm::map_to_vector(exprs, [](AffineExpr e) -> unsigned { return cast<AffineDimExpr>(e).getPosition(); }); - return linalg::interchangeGenericOp(rewriter, genericOp, perm); } }; @@ -210,6 +224,7 @@ struct FusionPreprocessingPass // operand shapes. 
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); + tensor::populateFoldIntoPackAndUnpackPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel index a9d3ccd7446f..ee4f9369d114 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel @@ -17,6 +17,7 @@ iree_lit_test_suite( srcs = enforce_glob( [ "annotate_dispatches.mlir", + "attention_fuse_by_expansion.mlir", "capture_dispatch_dynamic_dims.mlir", "capture_scf_for_dynamic_dims.mlir", "cleanup_tensor_shapes.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt index 9c6d0a122f1d..a88bb4ff46f5 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt @@ -15,6 +15,7 @@ iree_lit_test_suite( lit SRCS "annotate_dispatches.mlir" + "attention_fuse_by_expansion.mlir" "capture_dispatch_dynamic_dims.mlir" "capture_scf_for_dynamic_dims.mlir" "cleanup_tensor_shapes.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/attention_fuse_by_expansion.mlir new file mode 100644 index 000000000000..b243b5998b9f --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/attention_fuse_by_expansion.mlir @@ -0,0 +1,127 @@ +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-flow-bubble-up-expand-shapes, canonicalize, cse, canonicalize))" %s | FileCheck %s + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + +util.func public @attention_static(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x4096x64xf16> { + %0 = tensor.empty() : tensor<20x4096x64xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> + %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16> + util.return %expanded : tensor<2x10x4096x64xf16> +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> +// CHECK: func public @attention_static( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16> +// CHECK-SAME: %[[ARG3:.+]]: f16) +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : 
tensor<2x10x4096x64xf16> +// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16] +// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16] +// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64] +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]] +// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] : +// CHECK-SAME: outs(%[[EMPTY]] : +// CHECK: util.return %[[ATTENTION]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + +util.func public @attention_expand_all(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x2048x2x2x32xf16> { + %0 = tensor.empty() : tensor<20x4096x64xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> + %expanded = tensor.expand_shape %1 [[0, 1], [2, 3], [4, 5]] output_shape [2, 10, 2048, 2, 2, 32] : tensor<20x4096x64xf16> into tensor<2x10x2048x2x2x32xf16> + util.return %expanded : tensor<2x10x2048x2x2x32xf16> +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d4)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d7)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d6, d7)> +// CHECK: func public @attention_expand_all( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16> +// CHECK-SAME: %[[ARG3:.+]]: f16) +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x2048x2x2x32xf16> +// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3], [4]] output_shape [2, 10, 2048, 2, 16] : tensor<20x4096x16xf16> into tensor<2x10x2048x2x16xf16> +// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3]] output_shape [2, 10, 1024, 16] : tensor<20x1024x16xf16> into tensor<2x10x1024x16xf16> +// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [2, 10, 1024, 2, 32] : tensor<20x1024x64xf16> into tensor<2x10x1024x2x32xf16> +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]] +// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] : +// CHECK-SAME: outs(%[[EMPTY]] : +// CHECK: util.return %[[ATTENTION]] + +// ----- + + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + +util.func public @attention_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: f16) -> tensor<2x?x?x?xf16> { + %c0 = arith.constant 0 : index + %c1 = 
arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg0, %c1 : tensor + %d2 = tensor.dim %arg0, %c2 : tensor + %d3 = tensor.dim %arg1, %c1 : tensor + %d4 = tensor.dim %arg2, %c2 : tensor + %0 = tensor.empty(%d0, %d1, %d4) : tensor + %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor, tensor, tensor, f16) outs(%0 : tensor) -> tensor + %split = arith.divsi %d0, %c2 : index + %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape[2, %split, %d1, %d4] + : tensor into tensor<2x?x?x?xf16> + util.return %expanded : tensor<2x?x?x?xf16> +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<()[s0] -> (s0 floordiv 2)> +// CHECK: func public @attention_dynamic( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG3:.+]]: f16) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]] +// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]] +// CHECK-DAG: %[[SPLIT0:.+]] = arith.divui %[[D0]] +// CHECK-DAG: %[[VAL:.+]] = affine.apply #[[MAP4]]()[%[[D0]]] +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[VAL]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16> +// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]] +// CHECK-DAG: %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]] +// CHECK-DAG: %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK-DAG: %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]] +// CHECK-DAG: %[[SPLIT1:.+]] = arith.divui %[[D5]], %[[C2]] +// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]] +// CHECK-DAG: %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]] +// CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]] +// CHECK-DAG: %[[SPLIT2:.+]] = arith.divui %[[D8]], %[[C2]] +// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]] +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]] +// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] : +// CHECK-SAME: outs(%[[EMPTY]] : +// CHECK: util.return %[[ATTENTION]] diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir index 4cca860d7647..25385da8d865 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir @@ -1,9 +1,9 @@ -// RUN: iree-opt --split-input-file --verify-diagnostics 
--pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-convert-dispatch-regions-to-workgroups), cse, canonicalize, cse)" %s | FileCheck %s +// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions), cse, canonicalize, cse)" %s | FileCheck %s #map = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> { +util.func public @linalgext_scatter_dispatch() -> tensor<8192x16x8x128xf32> { %0 = tensor.empty() : tensor<4x1xi32> %1 = tensor.empty() : tensor<4x1xi64> %2 = tensor.empty() : tensor<4x1x16x8x128xf32> @@ -36,22 +36,24 @@ util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> { util.return %9 : tensor<8192x16x8x128xf32> } -// CHECK: util.func public @linalgext_scatter_fusion -// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups -// CHECK: %[[INDICES:.+]] = linalg.generic -// CHECK: %[[UPDATE:.+]] = linalg.generic -// CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>) -// CHECK: flow.dispatch.workgroups -// CHECK: %[[GEN2:.+]] = linalg.generic -// CHECK-SAME: ins(%[[INPUT:.+]] : tensor<8192x16x8x128xf32>) +// CHECK-LABEL: util.func public @linalgext_scatter_dispatch +// CHECK: %[[RESULT:.+]] = flow.dispatch.region +// CHECK: %[[INDICES:.+]] = linalg.generic +// CHECK: %[[UPDATE:.+]] = linalg.generic +// CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>) +// CHECK: flow.return %[[SCATTER_RESULT]] +// CHECK: flow.dispatch.region +// CHECK: %[[GEN2:.+]] = linalg.generic +// CHECK-SAME: ins(%[[INPUT:.+]] : tensor<8192x16x8x128xf32>) +// CHECK: flow.return %[[GEN2]] // ----- #map = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> { +util.func public @linalgext_scatter_clone() -> tensor<8192x16x8x128xf32> { %6 = tensor.empty() : tensor<4x1xi32> %2 = tensor.empty() : tensor<4x1x16x8x128xf32> %4 = tensor.empty() : tensor<10x8192x16x8x128xf32> @@ -69,8 +71,55 @@ util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> { util.return %8 : tensor<8192x16x8x128xf32> } -// CHECK: util.func public @linalgext_scatter_fusion -// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups -// CHECK: %[[OUTS:.+]] = tensor.extract_slice -// CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: outs(%[[OUTS]] : tensor<8192x16x8x128xf32>) +// CHECK-LABEL: util.func public @linalgext_scatter_clone +// CHECK: %[[RESULT:.+]] = flow.dispatch.region +// CHECK: %[[OUTS:.+]] = tensor.extract_slice +// CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: outs(%[[OUTS]] : tensor<8192x16x8x128xf32>) +// CHECK: flow.return %[[SCATTER_RESULT]] + +// ----- + +util.func public @attention_dispatch(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: f16, %arg4: tensor, %arg5: tensor, %arg6: tensor) -> tensor { + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 
d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor) outs(%arg4 : tensor) { + ^bb0(%in: f16, %out: f16): + %5 = arith.mulf %in, %in : f16 + linalg.yield %5 : f16 + } -> tensor + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor) outs(%arg4 : tensor) { + ^bb0(%in: f16, %out: f16): + %5 = arith.mulf %in, %in : f16 + linalg.yield %5 : f16 + } -> tensor + %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg2 : tensor) outs(%arg4 : tensor) { + ^bb0(%in: f16, %out: f16): + %5 = arith.mulf %in, %in : f16 + linalg.yield %5 : f16 + } -> tensor + + %3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%0, %1, %2, %arg3 : tensor, tensor, tensor, f16) outs(%arg4 : tensor) -> tensor + + %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3 : tensor) outs(%arg4 : tensor) { + ^bb0(%in: f16, %out: f16): + %5 = arith.mulf %in, %in : f16 + linalg.yield %5 : f16 + } -> tensor + util.return %4 : tensor +} + +// CHECK-LABEL: util.func public @attention_dispatch +// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region +// CHECK-NEXT: %[[GEN0:.+]] = linalg.generic +// CHECK: flow.return %[[GEN0]] +// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region +// CHECK-NEXT: %[[GEN1:.+]] = linalg.generic +// CHECK: flow.return %[[GEN1]] +// CHECK: %[[DISPATCH2:.+]] = flow.dispatch.region +// CHECK-NEXT: %[[GEN2:.+]] = linalg.generic +// CHECK: flow.return %[[GEN2]] +// CHECK: %[[RESULT:.+]] = flow.dispatch.region +// CHECK: %[[ATTN:.+]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[DISPATCH0]], %[[DISPATCH1]], %[[DISPATCH2]] +// CHECK: %[[GEN2:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ATTN]] +// CHECK: flow.return %[[GEN2]] diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir index b1865bc9c803..eacef6a24f55 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir @@ -117,11 +117,11 @@ util.func public @fuse_generic_gather2( // ----- -#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)> -util.func @output_transpose_map(%arg0: tensor<2x128x128x320xf32>) -> tensor<2x320x128x128xf16> { +#ident = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#perm = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)> +util.func @single_input_interchange(%arg0: tensor<2x128x128x320xf32>) -> tensor<2x320x128x128xf16> { %0 = tensor.empty() : tensor<2x320x128x128xf16> - %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x128x128x320xf32>) outs(%0 : tensor<2x320x128x128xf16>) { + %1 = linalg.generic {indexing_maps = [#perm, #ident], iterator_types = ["parallel", "parallel", 
"parallel", "parallel"]} ins(%arg0 : tensor<2x128x128x320xf32>) outs(%0 : tensor<2x320x128x128xf16>) { ^bb0(%in: f32, %out: f16): %2 = arith.truncf %in : f32 to f16 linalg.yield %2 : f16 @@ -129,8 +129,62 @@ util.func @output_transpose_map(%arg0: tensor<2x128x128x320xf32>) -> tensor<2x32 util.return %1 : tensor<2x320x128x128xf16> } -// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> -// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-LABEL: util.func public @output_transpose_map +// CHECK-DAG: #[[$IDENT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[$PERM_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)> +// CHECK-LABEL: util.func public @single_input_interchange +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x128x128x320xf32> +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x320x128x128xf16> // CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] +// CHECK-SAME: indexing_maps = [#[[$IDENT_MAP]], #[[$PERM_MAP]]] +// CHECK-SAME: ins(%[[ARG0]] : tensor<2x128x128x320xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x320x128x128xf16>) + +// ----- + +#ident = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#perm = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)> +util.func @multi_input_interchange(%arg0: tensor<2x128x128x320xf32>) -> tensor<2x320x128x128xf16> { + %0 = tensor.empty() : tensor<2x320x128x128xf16> + %1 = linalg.generic {indexing_maps = [#perm, #perm, #ident], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg0 : tensor<2x128x128x320xf32>, tensor<2x128x128x320xf32>) outs(%0 : tensor<2x320x128x128xf16>) { + ^bb0(%in: f32, %in_1: f32, %out: f16): + %2 = arith.addf %in, %in_1 : f32 + %3 = arith.truncf %2 : f32 to f16 + linalg.yield %3 : f16 + } -> tensor<2x320x128x128xf16> + util.return %1 : tensor<2x320x128x128xf16> +} + +// CHECK-DAG: #[[$IDENT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[$PERM_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)> +// CHECK-LABEL: util.func public @multi_input_interchange +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x128x128x320xf32> +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x320x128x128xf16> +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$IDENT_MAP]], #[[$IDENT_MAP]], #[[$PERM_MAP]]] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG0]] : tensor<2x128x128x320xf32>, tensor<2x128x128x320xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x320x128x128xf16>) + +// ----- + +#ident = affine_map<(d0, d1) -> (d0, d1)> +#perm0 = affine_map<(d0, d1) -> (d1, d0)> +util.func @multi_input_no_interchange(%arg0: tensor<10x10xf32>) -> tensor<10x10xf16> { + %0 = tensor.empty() : tensor<10x10xf16> + %1 = linalg.generic {indexing_maps = [#ident, #perm0, #perm0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg0 : tensor<10x10xf32>, tensor<10x10xf32>) outs(%0 : tensor<10x10xf16>) { + ^bb0(%in: f32, %in_1: f32, %out: f16): + %2 = arith.addf %in, %in_1 : f32 + %3 = arith.truncf %2 : f32 to f16 + linalg.yield %3 : f16 + } -> tensor<10x10xf16> + util.return %1 : tensor<10x10xf16> +} + +// CHECK-DAG: #[[$IDENT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[$PERM_MAP0:.+]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-LABEL: util.func public @multi_input_no_interchange +// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf32> +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<10x10xf16> +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$IDENT_MAP]], 
#[[$PERM_MAP0]], #[[$PERM_MAP0]]] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG0]] : tensor<10x10xf32>, tensor<10x10xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<10x10xf16>) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp index 1ec216b39a5c..5996310e10c3 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp @@ -97,10 +97,7 @@ struct LinalgFusionOpInterfaceAdapter return (llvm::cast(op).getMatchingIndexingMap(operand)); } - SmallVector getIndexingMaps(mlir::Operation *op) const { - // Note: this is different from linalg's implementation - // of `getIndexingMaps`. Call interface methods to get - // the vector of indexing maps for operands and results. + SmallVector getIndexingMapsArray(mlir::Operation *op) const { auto inputMaps = getIndexingMapsForOperands(op); llvm::append_range(inputMaps, getIndexingMapsForResults(op)); return inputMaps; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td index 8cbd1a4d21dd..596a925a5101 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td @@ -73,7 +73,7 @@ def LinalgFusionInterface : OpInterface<"LinalgFusionOpInterface"> { operand or result does not have an indexing map representation. }], /*retTy=*/"SmallVector", - /*methodName=*/"getIndexingMaps", + /*methodName=*/"getIndexingMapsArray", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -149,7 +149,11 @@ def LinalgFusionInterface : OpInterface<"LinalgFusionOpInterface"> { /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opOperand->getOwner() == $_op); - return $_op.getIndexingMapsForOperands()[opOperand->getOperandNumber()]; + if(opOperand->getOperandNumber() >= $_op.getNumDpsInputs()){ + return $_op.getIndexingMapsForResults()[opOperand->getOperandNumber() - $_op.getNumDpsInputs()]; + }else { + return $_op.getIndexingMapsForOperands()[opOperand->getOperandNumber()]; + } }] >, diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 9f9590b29e76..37d2133f5b8a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -1214,6 +1214,9 @@ LogicalResult AttentionOp::verify() { SmallVector indexingMaps = attnOp.getIndexingMapsArray(); FailureOr maybeOpInfo = AttentionOpDetail::get(indexingMaps); + if (failed(maybeOpInfo)) { + return attnOp->emitOpError("failed to verify op's indexing maps"); + } FloatType scaleElementType = dyn_cast(getScale().getType()); if (!scaleElementType) { @@ -1319,6 +1322,18 @@ SmallVector AttentionOp::getStaticLoopRanges() { return bounds; } +SmallVector AttentionOp::getIndexingMapsForOperands() { + auto maps = getIndexingMapsArray(); + return SmallVector(maps.begin(), + maps.begin() + getNumDpsInputs() - 1); +} + +SmallVector AttentionOp::getIndexingMapsForResults() { + auto maps = getIndexingMapsArray(); + return SmallVector(maps.begin() + getNumDpsInputs() - 1, + maps.end()); +} + //===----------------------------------------------------------------------===// // OnlineAttentionOp //===----------------------------------------------------------------------===// diff 
--git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index e99b4c660733..c29245269530 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -453,12 +453,15 @@ def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention", [DeclareOpInterfaceMethods, DestinationStyleOpInterface, LinalgExtInterface, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + "getTiledImplementation", + "generateResultTileValue"]>]> { let summary = "Attention operator"; let description = [{ Computes the scaled dot product attention function: diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir index 1d0280bc75ee..2f8d8efe9ded 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir @@ -746,3 +746,16 @@ func.func @illegal_attention_inputs(%query: tensor<192x1024x64xf32>, %key: tenso ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> return %1 : tensor<192x1024x64xf32> } + +// ----- + +func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> { + %0 = tensor.empty() : tensor<192x1024x64xf32> + %scale = arith.constant 1.0 : f32 + // expected-error @below {{'iree_linalg_ext.attention' op failed to verify op's indexing maps}} + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> + return %1 : tensor<192x1024x64xf32> +} diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel index 3bba99a362d5..e3b6bd94bdb1 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel @@ -39,6 +39,7 @@ iree_compiler_cc_library( "DecomposeWinogradPass.cpp", "PadContractionToBlockSize.cpp", "Passes.cpp", + "ReshapeFusion.cpp", "SplitReduction.cpp", "TileAttention.cpp", "TilingInterfaceImpl.cpp", @@ -46,6 +47,7 @@ iree_compiler_cc_library( hdrs = [ "Passes.h", "Passes.h.inc", + "Transforms.h", ], deps = [ ":PassesIncGen", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt index aecda4b79ff9..f4a1628739d3 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt @@ -25,6 +25,7 @@ iree_cc_library( HDRS "Passes.h" "Passes.h.inc" + "Transforms.h" SRCS "AggregatedOpInterfaceImpl.cpp" "ConvertConv2DToIm2ColOp.cpp" @@ -35,6 +36,7 @@ iree_cc_library( "DecomposeWinogradPass.cpp" "PadContractionToBlockSize.cpp" "Passes.cpp" + "ReshapeFusion.cpp" "SplitReduction.cpp" "TileAttention.cpp" 
"TilingInterfaceImpl.cpp" diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp new file mode 100644 index 000000000000..a354a0dfb023 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp @@ -0,0 +1,347 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// +// The content of this file is adapted from linalg's ElemenwiseOpFusion.cpp and +// modified to work with LinalgExt ops, specifically `LinalgExt::AttentionOp`. + +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" + +namespace mlir::iree_compiler::IREE::LinalgExt { + +namespace { + +/// Information needed to expand an operation to fold the reshape with +/// it. +class ExpansionInfo { +public: + // Computes the mapping from original dimensions of the op to the dimensions + // of the expanded op given the `indexingMap` of the fused operand/result of + // the op, the `reassocationMaps` of the reshape op and the shape of + // the expanded op. + template + LogicalResult compute(OpTy op, OpOperand *fusableOpOperand, + ArrayRef reassociationMaps, + ArrayRef expandedShape, + ArrayRef collapsedShape, + PatternRewriter &rewriter); + unsigned getOrigOpNumDims() const { return reassociation.size(); } + unsigned getExpandedOpNumDims() const { return expandedOpNumDims; } + ReassociationIndicesRef getExpandedDims(unsigned i) const { + return reassociation[i]; + } + ArrayRef getExpandedShapeOfDim(unsigned i) const { + return expandedShapeMap[i]; + } + ArrayRef getOriginalShape() const { return originalLoopExtent; } + +private: + /// Reassociation from the dimensions in the original operation to the + /// dimension of the expanded operation. + SmallVector reassociation; + /// Mapping from extent of loops in the original operation, to the extent of + /// loops in the expanded operation. + SmallVector> expandedShapeMap; + /// Extent of the loop in the original operation. + SmallVector originalLoopExtent; + unsigned expandedOpNumDims; +}; +} // namespace + +template +LogicalResult ExpansionInfo::compute(OpTy op, OpOperand *fusableOpOperand, + ArrayRef reassociationMaps, + ArrayRef expandedShape, + ArrayRef collapsedShape, + PatternRewriter &rewriter) { + if (reassociationMaps.empty()) + return failure(); + AffineMap fusedIndexMap = op.getMatchingIndexingMap(fusableOpOperand); + SmallVector originalLoopRange = op.getStaticLoopRanges(); + originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end()); + + reassociation.clear(); + expandedShapeMap.clear(); + // Compute the number of dimension in the expanded op that correspond to each + // dimension of the original op. 
+ SmallVector numExpandedDims(fusedIndexMap.getNumDims(), 1); + expandedShapeMap.resize(fusedIndexMap.getNumDims()); + for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) { + unsigned pos = cast(resultExpr.value()).getPosition(); + AffineMap foldedDims = reassociationMaps[resultExpr.index()]; + numExpandedDims[pos] = foldedDims.getNumResults(); + ArrayRef shape = + expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]); + expandedShapeMap[pos].assign(shape.begin(), shape.end()); + } + // The remaining dimensions remain the same. + for (unsigned i : llvm::seq(0, fusedIndexMap.getNumDims())) + if (expandedShapeMap[i].empty()) + expandedShapeMap[i] = {originalLoopExtent[i]}; + + // Compute reassociation map from the original op to the expanded op. + unsigned sum = 0; + reassociation.reserve(fusedIndexMap.getNumDims()); + for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) { + auto seq = llvm::seq(sum, sum + numFoldedDim.value()); + reassociation.emplace_back(seq.begin(), seq.end()); + sum += numFoldedDim.value(); + } + expandedOpNumDims = sum; + return success(); +} + +static AffineMap +getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, + const ExpansionInfo &expansionInfo) { + SmallVector newExprs; + for (AffineExpr expr : indexingMap.getResults()) { + unsigned pos = cast(expr).getPosition(); + SmallVector expandedExprs = llvm::to_vector<4>( + llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) { + return builder.getAffineDimExpr(static_cast(v)); + })); + newExprs.append(expandedExprs.begin(), expandedExprs.end()); + } + return AffineMap::get(expansionInfo.getExpandedOpNumDims(), + indexingMap.getNumSymbols(), newExprs, + builder.getContext()); +} + +static RankedTensorType getExpandedType(RankedTensorType originalType, + AffineMap indexingMap, + const ExpansionInfo &expansionInfo) { + SmallVector expandedShape; + for (AffineExpr expr : indexingMap.getResults()) { + unsigned dim = cast(expr).getPosition(); + auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim); + expandedShape.append(dimExpansion.begin(), dimExpansion.end()); + } + return RankedTensorType::get(expandedShape, originalType.getElementType()); +} + +static SmallVector +getReassociationForExpansion(AffineMap indexingMap, + const ExpansionInfo &expansionInfo) { + SmallVector reassociation; + unsigned numReshapeDims = 0; + for (AffineExpr expr : indexingMap.getResults()) { + unsigned dim = cast(expr).getPosition(); + auto numExpandedDims = expansionInfo.getExpandedDims(dim).size(); + SmallVector indices = llvm::to_vector<2>( + llvm::seq(numReshapeDims, numReshapeDims + numExpandedDims)); + reassociation.emplace_back(std::move(indices)); + numReshapeDims += numExpandedDims; + } + return reassociation; +} + +template +static bool isFusableWithReshapeByDimExpansion(OpTy op, + OpOperand *fusableOpOperand) { + // Is fusable only if: + // - All the indexing maps for operands and results are projected + // permutations. + // - The fused tensor is not a scalar. + // - All the loops for the reshaped operand are parallel loops. 
+ SmallVector iteratorTypes = op.getLoopIteratorTypes(); + AffineMap operandMap = op.getMatchingIndexingMap(fusableOpOperand); + return op.hasPureTensorSemantics() && + llvm::all_of( + op.getIndexingMapsArray(), + [](AffineMap map) { return map.isProjectedPermutation(); }) && + operandMap.getNumResults() > 0 && + llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) { + return iteratorTypes[cast(expr).getPosition()] == + utils::IteratorType::parallel; + }); +} + +static std::optional> fuseAttentionWithReshapeByExpansion( + AttentionOp attentionOp, Operation *reshapeOp, OpOperand *fusableOpOperand, + PatternRewriter &rewriter) { + assert(isFusableWithReshapeByDimExpansion(attentionOp, fusableOpOperand) && + "preconditions for fuse operation failed"); + + Location loc = attentionOp.getLoc(); + // Check if reshape is expanding or collapsing. + auto expandingReshapeOp = dyn_cast(*reshapeOp); + auto collapsingReshapeOp = dyn_cast(*reshapeOp); + bool isExpanding = (expandingReshapeOp != nullptr); + RankedTensorType expandedType = isExpanding + ? expandingReshapeOp.getResultType() + : collapsingReshapeOp.getSrcType(); + RankedTensorType collapsedType = isExpanding + ? expandingReshapeOp.getSrcType() + : collapsingReshapeOp.getResultType(); + + ExpansionInfo expansionInfo; + if (failed(expansionInfo.compute( + attentionOp, fusableOpOperand, + isExpanding ? expandingReshapeOp.getReassociationMaps() + : collapsingReshapeOp.getReassociationMaps(), + expandedType.getShape(), collapsedType.getShape(), rewriter))) + return std::nullopt; + + SmallVector expandedOpIndexingMaps = llvm::to_vector<4>( + llvm::map_range(attentionOp.getIndexingMapsArray(), [&](AffineMap m) { + return getIndexingMapInExpandedOp(rewriter, m, expansionInfo); + })); + + // Set insertion point to the attention op. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(attentionOp); + + SmallVector expandedOpOperands; + expandedOpOperands.reserve(attentionOp.getNumDpsInputs()); + for (OpOperand *opOperand : attentionOp.getDpsInputOperands()) { + if (opOperand == fusableOpOperand) { + expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc() + : collapsingReshapeOp.getSrc()); + continue; + } + if (auto opOperandType = + dyn_cast(opOperand->get().getType())) { + AffineMap indexingMap = attentionOp.getMatchingIndexingMap(opOperand); + RankedTensorType expandedOperandType = + getExpandedType(opOperandType, indexingMap, expansionInfo); + if (expandedOperandType != opOperand->get().getType()) { + // Reshape the operand to get the right type. 
+ SmallVector reassociation = + getReassociationForExpansion(indexingMap, expansionInfo); + if (failed(reshapeLikeShapesAreCompatible( + [&](const Twine &msg) { + return rewriter.notifyMatchFailure(attentionOp, msg); + }, + opOperandType.getShape(), expandedOperandType.getShape(), + reassociation, + /*isExpandingReshape=*/true))) + return std::nullopt; + expandedOpOperands.push_back(rewriter.create( + loc, expandedOperandType, opOperand->get(), reassociation)); + continue; + } + } + expandedOpOperands.push_back(opOperand->get()); + } + + SmallVector outputs; + for (OpOperand &opOperand : attentionOp.getDpsInitsMutable()) { + AffineMap indexingMap = attentionOp.getMatchingIndexingMap(&opOperand); + auto opOperandType = cast(opOperand.get().getType()); + RankedTensorType expandedOutputType = + getExpandedType(opOperandType, indexingMap, expansionInfo); + if (expandedOutputType != opOperand.get().getType()) { + SmallVector reassociation = + getReassociationForExpansion(indexingMap, expansionInfo); + if (failed(reshapeLikeShapesAreCompatible( + [&](const Twine &msg) { + return rewriter.notifyMatchFailure(attentionOp, msg); + }, + opOperandType.getShape(), expandedOutputType.getShape(), + reassociation, + /*isExpandingReshape=*/true))) + return std::nullopt; + outputs.push_back(rewriter.create( + loc, expandedOutputType, opOperand.get(), reassociation)); + } else { + outputs.push_back(opOperand.get()); + } + } + + // Create a new `AttentionOp` that has the computed operands/indexing maps. + TypeRange resultTypes = ValueRange(outputs).getTypes(); + auto fusedOp = rewriter.create( + attentionOp.getLoc(), resultTypes, expandedOpOperands[0], + expandedOpOperands[1], expandedOpOperands[2], expandedOpOperands[3], + outputs, rewriter.getAffineMapArrayAttr(expandedOpIndexingMaps)); + + // Reshape the result values to their original shape if this is a collapsing + // reshape folded into its consumer. + SmallVector resultVals; + for (OpResult opResult : attentionOp->getOpResults()) { + int64_t resultNumber = opResult.getResultNumber(); + if (resultTypes[resultNumber] != opResult.getType()) { + SmallVector reassociation = + getReassociationForExpansion( + attentionOp.getIndexingMapsForResults()[resultNumber], + expansionInfo); + resultVals.push_back(rewriter.create( + attentionOp.getLoc(), opResult.getType(), + fusedOp->getResult(resultNumber), reassociation)); + } else { + resultVals.push_back(fusedOp->getResult(resultNumber)); + } + } + // Assuming a single result. + return resultVals; +} + +struct FoldReshapeWithAttentionOpByExpansion + : public OpRewritePattern { + FoldReshapeWithAttentionOpByExpansion(MLIRContext *context, + linalg::ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFoldingReshapes(std::move(foldReshapes)) {} + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, + PatternRewriter &rewriter) const override { + auto producerResult = dyn_cast(reshapeOp.getSrc()); + if (!producerResult) { + return rewriter.notifyMatchFailure(reshapeOp, + "source not produced by an operation"); + } + + auto producer = dyn_cast(producerResult.getOwner()); + if (!producer) { + return rewriter.notifyMatchFailure(reshapeOp, + "producer is not an attention op"); + } + + if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) { + return rewriter.notifyMatchFailure(reshapeOp, + "fusion blocked by control function"); + } + + // Note: expand_shape can always be fused with attention, it is not checked + // as a precondition. 
It is asserted in `fuseWithReshapeByExpansion`. + std::optional> replacementValues = + fuseAttentionWithReshapeByExpansion( + producer, reshapeOp, + producer.getDpsInitOperand(producerResult.getResultNumber()), + rewriter); + if (!replacementValues) { + return rewriter.notifyMatchFailure(reshapeOp, + "fusion by expansion failed"); + } + + Value reshapeReplacement = + (*replacementValues)[cast(reshapeOp.getSrc()) + .getResultNumber()]; + if (auto collapseOp = + reshapeReplacement.getDefiningOp()) { + reshapeReplacement = collapseOp.getSrc(); + } + rewriter.replaceOp(reshapeOp, reshapeReplacement); + rewriter.replaceOp(producer, *replacementValues); + return success(); + } + linalg::ControlFusionFn controlFoldingReshapes; +}; + +void populateFoldReshapeOpsByExpansionPatterns( + RewritePatternSet &patterns, + const linalg::ControlFusionFn &controlFoldingReshapes) { + patterns.insert(std::make_unique( + patterns.getContext(), controlFoldingReshapes)); +} + +} // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp index 2565df8d6400..c0d86dc4114d 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp @@ -1782,6 +1782,29 @@ LogicalResult AttentionOp::getResultTilePosition( return success(); } +FailureOr +AttentionOp::generateResultTileValue(OpBuilder &builder, unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes) { + // Input offsets and sizes here are from the POV of the outputMap. We need to + // normalize these offsets and size for it to be useful. + + // Initialize normalized offsets with 0s and normalized sizes with original + // size. + SmallVector iterationDomain(getIterationDomain(builder)); + SmallVector normalizedSizes = + llvm::map_to_vector(iterationDomain, [](Range x) { return x.size; }); + SmallVector normalizedOffsets(getIterationDomainRank(), + builder.getIndexAttr(0)); + ArrayRef outputDims = getOutputMap().getResults(); + for (int i = 0; i < outputDims.size(); i++) { + int dim = cast(outputDims[i]).getPosition(); + normalizedOffsets[dim] = offsets[i]; + normalizedSizes[dim] = sizes[i]; + } + return getTiledImplementation(builder, normalizedOffsets, normalizedSizes); +} + //===----------------------------------------------------------------------===// // OnlineAttentionOp //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h new file mode 100644 index 000000000000..4d8606025ed2 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h @@ -0,0 +1,16 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" + +namespace mlir::iree_compiler::IREE::LinalgExt { + +// Fold expand_shape ops with their producers (only `AttentionOp` supported) +void populateFoldReshapeOpsByExpansionPatterns( + RewritePatternSet &patterns, + const linalg::ControlFusionFn &controlFoldingReshapes); + +}; // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir index 67a189b7d24b..147d5925346e 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir @@ -1652,6 +1652,57 @@ module attributes { transform.with_named_sequence } { // ----- +func.func @attention_fusion( + %query: tensor<2x10x4096x64xf16>, + %key: tensor<2x10x4096x64xf16>, + %value: tensor<2x10x4096x64xf16>, + %scale : f16, %bias : tensor<10x64xf16>) -> tensor<2x10x4096x64xf16> { + %0 = tensor.empty() : tensor<2x10x4096x64xf16> + %1 = iree_linalg_ext.attention { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} + ins(%query, %key, %value, %scale : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, f16) + outs(%0 : tensor<2x10x4096x64xf16>) -> tensor<2x10x4096x64xf16> + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%1, %bias : tensor<2x10x4096x64xf16>, tensor<10x64xf16>) + outs(%0 : tensor<2x10x4096x64xf16>) { + ^bb0(%b0 : f16, %b1 : f16, %b2 : f16): + %3 = arith.addf %b0, %b1 : f16 + linalg.yield %3 : f16 + } -> tensor<2x10x4096x64xf16> + return %2 : tensor<2x10x4096x64xf16> +} +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.any_op + %2, %loops = transform.structured.tile_using_forall %1 tile_sizes [1, 1, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %_, %__ = transform.structured.fuse_into_containing_op %0 into %loops : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-LABEL: func @attention_fusion( +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x4096x64xf16> +// CHECK: %[[RESULT:.+]] = scf.forall +// CHECK-SAME: shared_outs(%[[OUTS:.+]] = %[[EMPTY]]) +// CHECK: %[[EMPTY_SLICE:.+]] = tensor.extract_slice %[[EMPTY]] +// CHECK: %[[ATTENTION_SLICE:.+]] = iree_linalg_ext.attention +// CHECK-SAME: outs(%[[EMPTY_SLICE]] : +// CHECK: %[[OUTS_SLICE:.+]] = tensor.extract_slice %[[OUTS]] +// CHECK: %[[BIAS_SLICE:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ATTENTION_SLICE]], +// CHECK-SAME: outs(%[[OUTS_SLICE]] : +// CHECK: tensor.parallel_insert_slice %[[BIAS_SLICE]] into %[[OUTS]] +// CHECK: return %[[RESULT]] + +// ----- 
+ #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index bc0d55b9987c..8509dcba39d8 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -177,7 +177,7 @@ def test_run_unet_rocm(SDXL_UNET_COMMON_RUN_FLAGS, sdxl_unet_real_weights): args=[ f"--parameters=model={sdxl_unet_real_weights.path}", f"--module={VmfbManager.sdxl_unet_rocm_pipeline_vmfb.path}", - "--expected_f16_threshold=0.7f", + "--expected_f16_threshold=0.705f", ] + SDXL_UNET_COMMON_RUN_FLAGS, )
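For reference, here is a minimal sketch of how a downstream pass could register the new `LinalgExt::populateFoldReshapeOpsByExpansionPatterns` entry point next to the upstream linalg patterns, mirroring the `BubbleUpExpandShapes.cpp` hunk above. The function name `bubbleExpandShapes` and the allow-everything control function are illustrative placeholders, not code from this PR.

```cpp
// Sketch only: register the new LinalgExt reshape-expansion patterns alongside
// the existing linalg ones, as BubbleUpExpandShapesPass does in this patch.
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult bubbleExpandShapes(Operation *rootOp) {
  MLIRContext *ctx = rootOp->getContext();
  RewritePatternSet patterns(ctx);

  // Placeholder heuristic: allow every reshape to be propagated. A real pass
  // would filter on producer/consumer fusability (see
  // bubbleUpExpansionControlFn in BubbleUpExpandShapes.cpp).
  linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
    return true;
  };

  // Existing upstream patterns for linalg ops.
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
  // New patterns from this PR for iree_linalg_ext.attention.
  iree_compiler::IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
      patterns, controlFn);

  return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
```

Running both pattern sets together is what lets a `tensor.expand_shape` bubble up through `iree_linalg_ext.attention` the same way it already does through `linalg.generic` producers, which is exactly what the `attention_fuse_by_expansion.mlir` tests above check.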
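Similarly, a small sketch of what implementing `LinalgFusionOpInterface` on attention enables: fusion heuristics can query indexing maps on `iree_linalg_ext.attention` through the same interface they already use for linalg ops. The helper name and the exact header/namespace below are assumptions for illustration.

```cpp
// Sketch only: query the LinalgFusionOpInterface that this PR implements for
// iree_linalg_ext.attention. Helper name and header path are illustrative.
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Operation.h"

namespace LinalgExt = mlir::iree_compiler::IREE::LinalgExt;

// Returns true if the owner of `operand` implements the fusion interface and
// maps that operand with a projected permutation, a typical fusability check.
static bool operandMapIsProjectedPermutation(mlir::OpOperand &operand) {
  auto fusionOp =
      mlir::dyn_cast<LinalgExt::LinalgFusionOpInterface>(operand.getOwner());
  if (!fusionOp)
    return false;
  // With this PR, attention answers this query just like a linalg op does.
  mlir::AffineMap map = fusionOp.getMatchingIndexingMap(&operand);
  return map.isProjectedPermutation();
}
```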