From 6ac6be6b606d9d5b8a20acf562a8834d3500cda3 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 12 Aug 2024 17:27:35 -0400 Subject: [PATCH] [GlobalOpt] Improve unary elementwise propagation to consider broadcasted operands (#17903) For binary (or more operands) elementwise operations, if one of the operands is broadcasted or otherwise unaffected by a transposition, then it can effectively be treated like a unary elementwise operation for the purpose of propagation because propagating the transpose would introduce only one additional transpose on the input operand. This improves the unary elementwise propagation patterns to handle such cases. --- .github/workflows/pkgci_regression_test.yml | 4 +- .../PropagateLinalgTranspose.cpp | 207 ++++++++++++++---- .../test/propagate_linalg_transpose.mlir | 180 +++++++++++++++ 3 files changed, 352 insertions(+), 39 deletions(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index 6347f93db575..ea18f86f599c 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -342,7 +342,7 @@ jobs: --goldentime-rocm-clip-ms 18.5 \ --goldentime-rocm-vae-ms 315.0 \ --goldendispatch-rocm-unet 1714 \ - --goldendispatch-rocm-clip 1569 \ + --goldendispatch-rocm-clip 1311 \ --goldendispatch-rocm-vae 248 \ --goldensize-rocm-unet-bytes 2280000 \ --goldensize-rocm-clip-bytes 860000 \ @@ -364,7 +364,7 @@ jobs: --goldentime-rocm-clip-ms 15.5 \ --goldentime-rocm-vae-ms 74.0 \ --goldendispatch-rocm-unet 1714 \ - --goldendispatch-rocm-clip 1569 \ + --goldendispatch-rocm-clip 1311 \ --goldendispatch-rocm-vae 248 \ --goldensize-rocm-unet-bytes 2270000 \ --goldensize-rocm-clip-bytes 860000 \ diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp index 2dea4ad2ae4a..8fe7ed2f5cff 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp @@ -644,23 +644,71 @@ class FuseTransposeWithLinalgOpConsumer bool allowGeneralizing = false; }; -bool isUnaryElementwiseGeneric(linalg::GenericOp genericOp) { - if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInputs() != 1 || - !linalg::isElementwise(genericOp)) { - return false; +static bool isIndexingMapAffectedByTransposeMap( + AffineMap indexingMap, ArrayRef iterationSpacePermutation) { + int64_t prevIdx = -1; + for (auto result : indexingMap.getResults()) { + int64_t idx = + iterationSpacePermutation[cast(result).getPosition()]; + // Verify that the relative ordering of indices in the map remain the same. + // If not, then the transposition affects the access order for the given + // map (and associated operand). + if (idx <= prevIdx) { + return true; + } + prevIdx = idx; } + return false; +} - // Skip transposes and broadcasts. Transposes make more sense to fuse - // rather than propagate through, and broadcasts are cheaper to transpose - // before broadcasting. - if (genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(0)) != - genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0))) { - return false; +// Finds a single DPS input operand of the given |genericOp| that is affected by +// the |iterationSpacePermutation|. In other words, the permutation changes the +// relative ordering of any of the dimensions of that input operand. +// +// For example, with permutation [1, 0, 2], affine map (d0, d1, d2) -> (d0, d1) +// is affected by the permutation because the first two dimensions are iterated +// in a different order while (d0, d1, d2) -> (d0, d2) is unaffected. +// +// If no such operand is found or there is more than one such operation, nullptr +// is returned. +static OpOperand * +getSingleTransposedInputOperand(linalg::GenericOp genericOp, + ArrayRef iterationSpacePermutation) { + OpOperand *operand = nullptr; + for (auto input : genericOp.getDpsInputOperands()) { + if (!isIndexingMapAffectedByTransposeMap( + genericOp.getMatchingIndexingMap(input), + iterationSpacePermutation)) { + continue; + } + if (operand) { + return nullptr; + } + operand = input; } - return true; + return operand; +} + +// Returns a new list of indexing maps that composes the iteration space +// permutation map |transposeMap| with all indexing maps of |genericOp| except +// for the |transposedInputIdx|'th operand. The unchanged operand is expected +// to have an explicit `linalg.transpose` op constructed for it so its map does +// not need to be updated. +static SmallVector +getTransposedIndexingMaps(linalg::GenericOp genericOp, + int64_t transposedInputIdx, AffineMap transposeMap) { + SmallVector indexingMaps = genericOp.getIndexingMapsArray(); + for (unsigned i = 0, e = genericOp.getNumDpsInputs(); i < e; ++i) { + if (i == transposedInputIdx) { + continue; + } + indexingMaps[i] = indexingMaps[i].compose(transposeMap); + } + return indexingMaps; } -// Sinks a transpose through the input of a unary elementwise operation. +// Sinks a transpose through the input of a elementwise operation where the +// transposition of the iteration space only affects a single input operand. class SinkTransposeThroughUnaryElementwiseInput : public OpRewritePattern { public: @@ -669,22 +717,57 @@ class SinkTransposeThroughUnaryElementwiseInput LogicalResult matchAndRewrite(linalg::GenericOp genericOp, PatternRewriter &rewriter) const override { if (!IREE::Flow::isNonNullAndOutsideDispatch(genericOp)) { - return failure(); + return rewriter.notifyMatchFailure(genericOp, "pre-formed dispatch"); } - if (!isUnaryElementwiseGeneric(genericOp)) { - return rewriter.notifyMatchFailure(genericOp, "not unary elementwise"); + if (!linalg::isElementwise(genericOp)) { + return rewriter.notifyMatchFailure(genericOp, "non-elementwise generic"); } - auto transposeOp = - genericOp.getDpsInputs()[0].getDefiningOp(); - if (!transposeOp) { - return rewriter.notifyMatchFailure(genericOp, "no transpose operand"); + if (genericOp.getNumDpsInits() != 1) { + return rewriter.notifyMatchFailure(genericOp, + "unimplemented: multiple results"); } - if (!transposeOp->hasOneUse()) { + AffineMap resultMap = + genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0)); + if (!resultMap.isIdentity()) { return rewriter.notifyMatchFailure( - genericOp, "do not propagate multi-use transpose"); + genericOp, "unimplemented: non-identity result map"); + } + + linalg::TransposeOp transposeOp; + OpOperand *inputOperand; + for (auto input : genericOp.getDpsInputOperands()) { + // Skip broadcasted operands and transposed operands. If the input is + // broadcasted then we would not want to propagate because that would + // do the transpose on larger data, and if transposed we would rather + // simply compose the transposes (handled in a separate pattern). + if (genericOp.getMatchingIndexingMap(input) != resultMap) { + continue; + } + + auto maybeTransposeOp = input->get().getDefiningOp(); + // Skip multi-use transposes. + if (!maybeTransposeOp || !maybeTransposeOp->hasOneUse()) { + continue; + } + + auto transposableInputOperand = getSingleTransposedInputOperand( + genericOp, maybeTransposeOp.getPermutation()); + // Skip if more than one operand is affected by the transpose. + if (transposableInputOperand != input) { + continue; + } + + transposeOp = maybeTransposeOp; + inputOperand = transposableInputOperand; + break; + } + + if (!transposeOp) { + return rewriter.notifyMatchFailure(genericOp, + "no single use transpose operand"); } ArrayRef perm = transposeOp.getPermutation(); @@ -694,18 +777,30 @@ class SinkTransposeThroughUnaryElementwiseInput Value newInit = createTransposeInit(rewriter, genericOp.getDpsInits()[0], invPerm); - // We do not need to update indexing maps because this is a unary - // elementwise op where the input and output maps are the same. Just - // replace the operands with transposed variants. - auto newGenericOp = mlir::clone(rewriter, genericOp, newInit.getType(), - {transposeOp.getInput(), newInit}); + // We do not need to update iterator types because this is an elementwise + // op. We just need to update the indexing maps of all other input operands + // by composing the transpose map. + AffineMap transposeMap = + AffineMap::getPermutationMap(perm, rewriter.getContext()); + SmallVector indexingMaps = getTransposedIndexingMaps( + genericOp, inputOperand->getOperandNumber(), transposeMap); + + SmallVector newOperands = genericOp->getOperands(); + newOperands[inputOperand->getOperandNumber()] = transposeOp.getInput(); + newOperands[genericOp.getDpsInitOperand(0)->getOperandNumber()] = newInit; + + auto newGenericOp = + mlir::clone(rewriter, genericOp, newInit.getType(), newOperands); + newGenericOp.setIndexingMapsAttr( + rewriter.getAffineMapArrayAttr(indexingMaps)); rewriter.replaceOp( genericOp, createTranspose(rewriter, newGenericOp->getResult(0), perm)); return success(); } }; -// Bubbles a transpose through the init of a unary elementwise operation. +// Bubbles a transpose through the init of a elementwise operation where the +// transposition of the iteration space only affects a single input operand. class BubbleTransposeThroughUnaryElementwiseDpsInit : public OpRewritePattern { public: @@ -715,33 +810,64 @@ class BubbleTransposeThroughUnaryElementwiseDpsInit PatternRewriter &rewriter) const override { auto genericOp = transposeOp.getInput().getDefiningOp(); if (!genericOp) { - return failure(); + return rewriter.notifyMatchFailure(transposeOp, "non-generic producer"); + } + + if (genericOp.getNumDpsInits() != 1) { + return rewriter.notifyMatchFailure(transposeOp, + "unimplemented: multiple results"); } + if (!IREE::Flow::isNonNullAndOutsideDispatch({genericOp, transposeOp})) { return failure(); } - if (!isUnaryElementwiseGeneric(genericOp)) { - return rewriter.notifyMatchFailure(genericOp, "not unary elementwise"); + if (!linalg::isElementwise(genericOp) || + !genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0)) + .isIdentity()) { + return rewriter.notifyMatchFailure(transposeOp, "not elementwise"); } if (!genericOp->hasOneUse()) { - return rewriter.notifyMatchFailure(genericOp, "not single user"); + return rewriter.notifyMatchFailure(transposeOp, "not single user"); } ArrayRef perm = transposeOp.getPermutation(); - Value newTranspose = - createTranspose(rewriter, genericOp.getOperand(0), perm); + auto invPerm = invertPermutationVector(perm); + + auto inputOperand = getSingleTransposedInputOperand(genericOp, invPerm); + if (!inputOperand || + !genericOp.getMatchingIndexingMap(inputOperand).isIdentity()) { + return rewriter.notifyMatchFailure( + genericOp, "no single transposable input operand"); + } + + Value newTranspose = createTranspose(rewriter, inputOperand->get(), perm); // Create a new empty init for the transposed generic. Value newInit = createTransposeInit(rewriter, genericOp.getDpsInits()[0], perm); + SmallVector newOperands = genericOp->getOperands(); + newOperands[inputOperand->getOperandNumber()] = newTranspose; + newOperands[genericOp.getDpsInitOperand(0)->getOperandNumber()] = newInit; + + AffineMap transposeMap = + AffineMap::getPermutationMap(invPerm, rewriter.getContext()); + + // We do not need to update iterator types because this is an elementwise + // op. We just need to update the indexing maps of all other input operands + // by composing the transpose map. + SmallVector indexingMaps = getTransposedIndexingMaps( + genericOp, inputOperand->getOperandNumber(), transposeMap); + // We do not need to update indexing maps because this is a unary // elementwise op where the input and output maps are the same. Just // replace the operands with transposed variants. - auto newGenericOp = mlir::clone(rewriter, genericOp, newInit.getType(), - {newTranspose, newInit}); + auto newGenericOp = + mlir::clone(rewriter, genericOp, newInit.getType(), newOperands); + newGenericOp.setIndexingMapsAttr( + rewriter.getAffineMapArrayAttr(indexingMaps)); rewriter.replaceOp(transposeOp, newGenericOp); return success(); } @@ -912,6 +1038,7 @@ void PropagateLinalgTransposePass::runOnOperation() { context, /*benefit=*/2); if (failed( applyPatternsAndFoldGreedily(funcOp, std::move(sinkingPatterns)))) { + funcOp.emitError("Transpose initial sinking patterns failed"); return signalPassFailure(); } } @@ -968,6 +1095,7 @@ void PropagateLinalgTransposePass::runOnOperation() { populateCommonCanonicalizationPatterns(context, bubblingPatterns); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(bubblingPatterns)))) { + funcOp.emitError("Transpose bubbling patterns failed"); return signalPassFailure(); } } @@ -1020,8 +1148,13 @@ void PropagateLinalgTransposePass::runOnOperation() { populateCommonCanonicalizationPatterns(context, sinkingPatterns); sinkingPatterns.add( context, /*benefit=*/2); - if (failed( - applyPatternsAndFoldGreedily(funcOp, std::move(sinkingPatterns)))) { + GreedyRewriteConfig config; + // TODO: This is inefficient. Consider rewriting this pass to use a + // worklist of just the transpose operations. + config.maxIterations = GreedyRewriteConfig::kNoLimit; + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(sinkingPatterns), + config))) { + funcOp.emitError("Transpose sinking patterns failed"); return signalPassFailure(); } } diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir index 939e650afc29..6b9571666808 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir @@ -485,3 +485,183 @@ util.func public @bubble_through_matmul(%lhs: tensor<16x16xf32>, // APROP-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] // APROP-SAME: outs(%[[EMPTY]] : tensor<16x16xf32>) // APROP: util.return %[[MATMUL]] + +// ----- + +util.func public @propagate_transpose_down_through_broadcast_elementwise(%arg0: tensor<3x4x2xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x3x4xf32> { + %empty = tensor.empty(): tensor<2x3x4xf32> + %transposed = linalg.transpose ins(%arg0 : tensor<3x4x2xf32>) + outs(%empty : tensor<2x3x4xf32>) permutation = [2, 0, 1] + %0 = linalg.generic {indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%transposed, %arg1 : tensor<2x3x4xf32>, tensor<3x4xf32>) + outs(%empty : tensor<2x3x4xf32>) { + ^bb0(%in: f32, %in1: f32, %out: f32): + %add = arith.addf %in, %in1 : f32 + linalg.yield %add : f32 + } -> tensor<2x3x4xf32> + util.return %0 : tensor<2x3x4xf32> +} + +// SINK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// SINK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// SINK-LABEL: util.func public @propagate_transpose_down_through_broadcast_elementwise +// SINK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<3x4x2xf32> +// SINK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<3x4xf32> +// SINK: %[[ELEM:.+]] = linalg.generic +// SINK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]] +// SINK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<3x4x2xf32>, tensor<3x4xf32> +// SINK: arith.addf +// SINK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ELEM]] : tensor<3x4x2xf32> +// SINK-SAME: outs({{.*}} : tensor<2x3x4xf32>) +// SINK-SAME: permutation = [2, 0, 1] +// SINK: util.return %[[TRANSPOSE]] : tensor<2x3x4xf32> + +// ----- + +util.func public @propagate_transpose_down_through_multi_operand_elementwise(%arg0: tensor<3x4x2xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x3x4xf32> { + %empty = tensor.empty(): tensor<2x3x4xf32> + %t1 = linalg.transpose ins(%arg0 : tensor<3x4x2xf32>) + outs(%empty : tensor<2x3x4xf32>) permutation = [2, 0, 1] + %empty2 = tensor.empty(): tensor<4x3xf32> + %t2 = linalg.transpose ins(%arg1 : tensor<3x4xf32>) + outs(%empty2 : tensor<4x3xf32>) permutation = [1, 0] + %0 = linalg.generic {indexing_maps = [ + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%t2, %t1 : tensor<4x3xf32>, tensor<2x3x4xf32>) + outs(%empty : tensor<2x3x4xf32>) { + ^bb0(%in: f32, %in1: f32, %out: f32): + %add = arith.addf %in, %in1 : f32 + linalg.yield %add : f32 + } -> tensor<2x3x4xf32> + util.return %0 : tensor<2x3x4xf32> +} + +// Verify that it first selects the correct transpose to propagate and then +// fuses the transpose on the broadcasted operand. + +// SINK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// SINK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// SINK-LABEL: util.func public @propagate_transpose_down_through_multi_operand_elementwise +// SINK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<3x4x2xf32> +// SINK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<3x4xf32> +// SINK: %[[ELEM:.+]] = linalg.generic +// SINK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP1]]] +// SINK-SAME: ins(%[[ARG1]], %[[ARG0]] : tensor<3x4xf32>, tensor<3x4x2xf32> +// SINK: arith.addf +// SINK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ELEM]] : tensor<3x4x2xf32> +// SINK-SAME: outs({{.*}} : tensor<2x3x4xf32>) +// SINK-SAME: permutation = [2, 0, 1] +// SINK: util.return %[[TRANSPOSE]] : tensor<2x3x4xf32> + +// ----- + +util.func public @sink_transpose_down_to_broadcast_elementwise(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x3x4xf32> { + %empty = tensor.empty(): tensor<2x3x4xf32> + %transposed = linalg.transpose ins(%arg0 : tensor<3x4x2xf32>) + outs(%empty : tensor<2x3x4xf32>) permutation = [2, 0, 1] + %0 = linalg.generic {indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%transposed, %arg1 : tensor<2x3x4xf32>, tensor<2x4xf32>) + outs(%empty : tensor<2x3x4xf32>) { + ^bb0(%in: f32, %in1: f32, %out: f32): + %add = arith.addf %in, %in1 : f32 + linalg.yield %add : f32 + } -> tensor<2x3x4xf32> + util.return %0 : tensor<2x3x4xf32> +} + +// Verify that the transpose is fused rather than propagated because the +// broadcast operand would be affected by the transpose. + +// SINK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)> +// SINK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// SINK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// SINK-LABEL: util.func public @sink_transpose_down_to_broadcast_elementwise +// SINK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<3x4x2xf32> +// SINK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<2x4xf32> +// SINK: %[[ELEM:.+]] = linalg.generic +// SINK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] +// SINK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<3x4x2xf32>, tensor<2x4xf32> +// SINK: arith.addf +// SINK: util.return %[[ELEM]] : tensor<2x3x4xf32> + +// ----- + +util.func public @propagate_transpose_up_through_broadcast_elementwise(%arg0: tensor<2x3x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4x2xf32> { + %empty = tensor.empty(): tensor<2x3x4xf32> + %0 = linalg.generic {indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<3x4xf32>) + outs(%empty : tensor<2x3x4xf32>) { + ^bb0(%in: f32, %in1: f32, %out: f32): + %add = arith.addf %in, %in1 : f32 + linalg.yield %add : f32 + } -> tensor<2x3x4xf32> + %empty1 = tensor.empty(): tensor<3x4x2xf32> + %transposed = linalg.transpose ins(%0 : tensor<2x3x4xf32>) + outs(%empty1 : tensor<3x4x2xf32>) permutation = [1, 2, 0] + util.return %transposed : tensor<3x4x2xf32> +} + +// BUBBLE-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// BUBBLE-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// BUBBLE-LABEL: util.func public @propagate_transpose_up_through_broadcast_elementwise +// BUBBLE-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3x4xf32> +// BUBBLE-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<3x4xf32> +// BUBBLE: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<2x3x4xf32> +// BUBBLE-SAME: outs({{.*}} : tensor<3x4x2xf32>) +// BUBBLE-SAME: permutation = [1, 2, 0] +// BUBBLE: %[[ELEM:.+]] = linalg.generic +// BUBBLE-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]] +// BUBBLE-SAME: ins(%[[TRANSPOSE]], %[[ARG1]] : tensor<3x4x2xf32>, tensor<3x4xf32> +// BUBBLE: arith.addf +// BUBBLE: util.return %[[ELEM]] : tensor<3x4x2xf32> + +// ----- + +util.func public @bubble_transpose_to_broadcast_elementwise(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x2xf32> { + %empty = tensor.empty(): tensor<2x3x4xf32> + %0 = linalg.generic {indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<2x4xf32>) + outs(%empty : tensor<2x3x4xf32>) { + ^bb0(%in: f32, %in1: f32, %out: f32): + %add = arith.addf %in, %in1 : f32 + linalg.yield %add : f32 + } -> tensor<2x3x4xf32> + %empty1 = tensor.empty(): tensor<3x4x2xf32> + %transposed = linalg.transpose ins(%0 : tensor<2x3x4xf32>) + outs(%empty1 : tensor<3x4x2xf32>) permutation = [1, 2, 0] + util.return %transposed : tensor<3x4x2xf32> +} + +// Verify that the transpose is fused rather than propagated because the +// broadcast operand would be affected by the transpose. + +// BUBBLE-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)> +// BUBBLE-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// BUBBLE-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// BUBBLE-LABEL: util.func public @bubble_transpose_to_broadcast_elementwise +// BUBBLE-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3x4xf32> +// BUBBLE-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<2x4xf32> +// BUBBLE: %[[ELEM:.+]] = linalg.generic +// BUBBLE-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] +// BUBBLE-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x4xf32> +// BUBBLE: arith.addf +// BUBBLE: util.return %[[ELEM]] : tensor<3x4x2xf32>