From 41115bba05960e563791ce6ed1af26093f4fab1e Mon Sep 17 00:00:00 2001
From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Date: Tue, 26 Nov 2024 15:25:35 -0800
Subject: [PATCH] [Codegen] Bubble up transpose of attention V and try to fuse
 it with other ops before attention (#19250)

The Flash Attention transpose-V variant is significantly faster than the
non-transpose-V variant, because many matmul intrinsics are mmtb (matmul
with transposed B) by default. Hence, generating FA in its transpose-V
form allows better, more contiguous reads from shared memory to
registers, improving attention performance quite a bit.

This PR exposes the attention-transposeV form by generating a
linalg.transpose on V while bubbling up transposes, such that the graph
gets some opportunities to fuse the transpose-V into its producer. I
have also confirmed that if we do not find any producer, the transpose
will indeed fuse back with the attentionOp. Hence, in the worst case, we
get the same performance as before this PR.

Additionally, we modify elementwise op fusion to try to fuse the
transpose with other ops before letting it get fused back into
attention.
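For example (a simplified sketch based on the new
propagate_linalg_transpose test added below; indexing maps elided), given

  %attn = iree_linalg_ext.attention ...
            ins(%q, %k, %v, %scale : tensor<2x10x4096x64xf16>,
                tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, f16) ...

the bubbling pass now produces

  %init = tensor.empty() : tensor<2x10x64x4096xf16>
  %v_t = linalg.transpose ins(%v : tensor<2x10x4096x64xf16>)
           outs(%init : tensor<2x10x64x4096xf16>) permutation = [0, 1, 3, 2]
  %attn = iree_linalg_ext.attention ...
            ins(%q, %k, %v_t, %scale : tensor<2x10x4096x64xf16>,
                tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>, f16) ...

so %v_t can fuse with the producer of %v (e.g. a dequantization
linalg.generic), or fold back into the attention op if there is none.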
---------

Signed-off-by: Stanley Winata
---
 .github/workflows/pkgci_regression_test.yml   |  16 +--
 .../attention_and_matmul_spec_punet.mlir      |  71 ++++++++++++
 .../Dialect/LinalgExt/Transforms/Transforms.h |   6 +
 .../LinalgExt/Transforms/TransposeFusion.cpp  | 105 ++++++++++++++++++
 .../DispatchCreation/ElementwiseOpFusion.cpp  |  28 +++--
 .../test/elementwise_op_fusion.mlir           |  51 +++++++++
 .../PropagateLinalgTranspose.cpp              |  10 ++
 .../test/propagate_linalg_transpose.mlir      |  44 ++++++++
 8 files changed, 314 insertions(+), 17 deletions(-)

diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index 8b0b4e0189d2..7a67778c0585 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -220,7 +220,7 @@ jobs:
           --goldentime-rocm-unet-ms 419.0 \
           --goldentime-rocm-clip-ms 18.5 \
           --goldentime-rocm-vae-ms 337.0 \
-          --goldendispatch-rocm-unet 1531 \
+          --goldendispatch-rocm-unet 1602 \
           --goldendispatch-rocm-clip 1139 \
           --goldendispatch-rocm-vae 246 \
           --goldensize-rocm-unet-bytes 2280000 \
@@ -238,21 +238,21 @@
       run: |
         source ${VENV_DIR}/bin/activate
         pytest ./experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py \
-          --goldentime-rocm-e2e-ms 372.0 \
-          --goldentime-rocm-unet-ms 95.0 \
+          --goldentime-rocm-e2e-ms 330.0 \
+          --goldentime-rocm-unet-ms 80.0 \
           --goldentime-rocm-clip-ms 15.5 \
           --goldentime-rocm-vae-ms 80.0 \
-          --goldendispatch-rocm-unet 1531 \
+          --goldendispatch-rocm-unet 1602 \
           --goldendispatch-rocm-clip 1139 \
           --goldendispatch-rocm-vae 246 \
           --goldensize-rocm-unet-bytes 2270000 \
           --goldensize-rocm-clip-bytes 860000 \
           --goldensize-rocm-vae-bytes 840000 \
-          --goldentime-rocm-punet-int8-fp16-ms 55 \
-          --goldendispatch-rocm-punet-int8-fp16 1284 \
+          --goldentime-rocm-punet-int8-fp16-ms 53 \
+          --goldendispatch-rocm-punet-int8-fp16 1424 \
           --goldensize-rocm-punet-int8-fp16-bytes 2560000 \
-          --goldentime-rocm-punet-int8-fp8-ms 59 \
-          --goldendispatch-rocm-punet-int8-fp8 1564 \
+          --goldentime-rocm-punet-int8-fp8-ms 53 \
+          --goldendispatch-rocm-punet-int8-fp8 1704 \
           --goldensize-rocm-punet-int8-fp8-bytes 2800000 \
           --rocm-chip gfx942 \
           --log-cli-level=info \
diff --git a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir
index a566203907e4..7b0944471990 100644
--- a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir
+++ b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir
@@ -208,6 +208,41 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
     transform.yield %cont, %config : !transform.any_op, !transform.any_param
   }
+
+  // Variant of matmul_like_Bx20x1024x64x1280_i8xi8xi32 from Transposed-V.
+  transform.named_sequence @match_matmul_like_Bx20x64x1024x1280_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
+    ^bb0(%lhs: tensor<?x1024x1280xi8>, %rhs: tensor<20x64x1280xi8>, %out: tensor<?x20x64x1024xi32>):
+      %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                                             affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d4)>,
+                                             affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+                            iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
+          ins(%lhs, %rhs : tensor<?x1024x1280xi8>, tensor<20x64x1280xi8>)
+          outs(%out : tensor<?x20x64x1024xi32>) {
+      ^bb0(%in: i8, %in_0: i8, %acc: i32):
+        %18 = arith.extsi %in : i8 to i32
+        %19 = arith.extsi %in_0 : i8 to i32
+        %20 = arith.muli %18, %19 : i32
+        %21 = arith.addi %acc, %20 : i32
+        linalg.yield %21 : i32
+      } -> tensor<?x20x64x1024xi32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                   mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                   subgroup_m_count = 2, subgroup_n_count = 2,
+                                                   reduction = [0, 0, 0, 0, 128],
+                                                   workgroup = [1, 1, 160, 64, 0]}>,
+      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+                                                        workgroup_size = [128, 2, 1] subgroup_size = 64,
+        {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>}>
+      > -> !transform.any_param
+    transform.yield %cont, %config : !transform.any_op, !transform.any_param
+  }
+
   transform.named_sequence @match_matmul_like_Bx20x64x64x2048_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
     -> (!transform.any_op, !transform.any_param) {
     %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
@@ -239,6 +274,38 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
     transform.yield %cont, %config : !transform.any_op, !transform.any_param
   }
+
+  // Variant of matmul_like_Bx20x64x64x2048_i8xi8xi32 from Transposed-V.
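+  // When the bubbled transpose-V fuses into the i8 contraction producing V,
+  // the result is written in the transposed layout, so it no longer matches
+  // the spec above and needs its own matcher and lowering config.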
+  transform.named_sequence @match_matmul_like_Bx20x64x64x2048_transposev_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
+    ^bb0(%lhs: tensor<?x64x2048xi8>, %rhs: tensor<20x64x2048xi8>, %out: tensor<?x20x64x64xi32>):
+      %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                                             affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d4)>,
+                                             affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+                            iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
+          ins(%lhs, %rhs : tensor<?x64x2048xi8>, tensor<20x64x2048xi8>)
+          outs(%out : tensor<?x20x64x64xi32>) {
+      ^bb0(%in: i8, %in_0: i8, %acc: i32):
+        %18 = arith.extsi %in : i8 to i32
+        %19 = arith.extsi %in_0 : i8 to i32
+        %20 = arith.muli %18, %19 : i32
+        %21 = arith.addi %acc, %20 : i32
+        linalg.yield %21 : i32
+      } -> tensor<?x20x64x64xi32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                   mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                   subgroup_m_count = 2, subgroup_n_count = 1,
+                                                   reduction = [0, 0, 0, 0, 128],
+                                                   workgroup = [1, 1, 320, 32, 0]}>,
+      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+                                                        workgroup_size = [64, 2, 1] subgroup_size = 64,
+        {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>}>
+      > -> !transform.any_param
+    transform.yield %cont, %config : !transform.any_op, !transform.any_param
+  }
+
   transform.named_sequence @match_matmul_like_Bx10x4096x64x640_i8xi8xi32(%cont: !transform.any_op {transform.readonly})
     -> (!transform.any_op, !transform.any_param) {
     %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {
@@ -302,6 +369,10 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
         , @match_matmul_like_Bx10x4096x64x640_i8xi8xi32 -> @apply_op_config
         , @match_matmul_like_Bx20x64x64x2048_i8xi8xi32 -> @apply_op_config
 
+        // Transpose-V generated contraction.
+        , @match_matmul_like_Bx20x64x1024x1280_i8xi8xi32 -> @apply_op_config
+        , @match_matmul_like_Bx20x64x64x2048_transposev_i8xi8xi32 -> @apply_op_config
+
         // TUNING_MATCH_END DO NOT REMOVE
       : (!transform.any_op) -> (!transform.any_op)
     transform.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
index aec6bd704e5d..8bf84cab2574 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
@@ -19,6 +19,12 @@ void populateFuseLinalgExtOpsWithTransposes(
     RewritePatternSet &patterns,
     const linalg::ControlFusionFn &controlFusionFn);
 
+/// Bubble up transpose-like ops from LinalgExt ops (only `AttentionOp`
+/// supported).
+void populateBubbleTransposeFromLinalgExtOps(
+    RewritePatternSet &patterns,
+    const linalg::ControlFusionFn &controlFusionFn);
+
 /// Helper struct to hold the results of collapsing an operation.
 struct CollapseResult {
   SmallVector<Value> results;
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp
index 2d158d54014a..bcc94ec951c0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp
@@ -7,6 +7,7 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -101,6 +102,103 @@ struct FuseTransposeWithAttentionOp final
 private:
   linalg::ControlFusionFn controlFn;
 };
+
+// Bubbles transpose-V out of attention to expose the more performant
+// attention-transposeV form.
+struct BubbleTransposeVFromAttentionOp
+    : public OpRewritePattern<LinalgExt::AttentionOp> {
+  BubbleTransposeVFromAttentionOp(MLIRContext *context,
+                                  linalg::ControlFusionFn controlFn,
+                                  PatternBenefit benefit = 1)
+      : OpRewritePattern<LinalgExt::AttentionOp>(context, benefit),
+        controlFn(controlFn) {}
+
+  LogicalResult matchAndRewrite(LinalgExt::AttentionOp attentionOp,
+                                PatternRewriter &rewriter) const override {
+    // Only check V because we are only bubbling transpose-V.
+    OpOperand *valueOpOperand = &attentionOp.getValueMutable();
+    if (controlFn && !controlFn(valueOpOperand)) {
+      return rewriter.notifyMatchFailure(
+          attentionOp, "Expected attentionOp and producer of V to be non-null "
+                       "and outside dispatch.");
+    }
+
+    // Extract attention indexing information.
+    AffineMap qMap = attentionOp.getQueryMap();
+    AffineMap kMap = attentionOp.getKeyMap();
+    AffineMap vMap = attentionOp.getValueMap();
+    AffineMap oMap = attentionOp.getOutputMap();
+    FailureOr<AttentionOpDetail> maybeOpInfo =
+        AttentionOpDetail::get(qMap, kMap, vMap, oMap);
+    if (failed(maybeOpInfo)) {
+      return failure();
+    }
+
+    // Only handle a single K2 dim and a single N dim for now.
+    if (maybeOpInfo->getK2Dims().size() != 1 ||
+        maybeOpInfo->getNDims().size() != 1) {
+      return failure();
+    }
+
+    // Check that V has the standard (non-transposed) map.
+    AffineExpr k2Dim =
+        rewriter.getAffineDimExpr(maybeOpInfo->getK2Dims().back());
+    AffineExpr nDim = rewriter.getAffineDimExpr(maybeOpInfo->getNDims().back());
+    int64_t vRank = vMap.getNumResults();
+    // TODO: This check is quite conservative; in the future we should simply
+    // check vMap.getResultPosition(k2Dim) > vMap.getResultPosition(nDim).
+    if (vMap.getResult(vRank - 1) != nDim ||
+        vMap.getResult(vRank - 2) != k2Dim) {
+      return failure();
+    }
+
+    // Get dimension positions to prepare for the transpose.
+    std::optional<unsigned> maybeK2Pos = vMap.getResultPosition(k2Dim);
+    std::optional<unsigned> maybeNPos = vMap.getResultPosition(nDim);
+    assert(maybeK2Pos.has_value() && maybeNPos.has_value() &&
+           "Expected K2 dim and N dim to be in V-map.");
+    int64_t k2Pos = maybeK2Pos.value();
+    int64_t nPos = maybeNPos.value();
+    SmallVector<int64_t> perm = llvm::to_vector(llvm::seq<int64_t>(0, vRank));
+    std::swap(perm[k2Pos], perm[nPos]);
+
+    // Expose a transpose op for V.
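+    // For example, for a V map (b, m, k1, k2, n) -> (b, k2, n), perm is
+    // [0, 2, 1]: the transposed V has shape (B, N, K2), and the rewritten
+    // value map computed below becomes (b, m, k1, k2, n) -> (b, n, k2).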
+    Location loc = attentionOp.getLoc();
+    Value value = attentionOp.getValue();
+    auto valueType = dyn_cast<ShapedType>(value.getType());
+    auto valueElType = valueType.getElementType();
+    SmallVector<OpFoldResult> transVShape =
+        tensor::getMixedSizes(rewriter, loc, value);
+    applyPermutationToVector(transVShape, perm);
+    Value initTransV =
+        rewriter.create<tensor::EmptyOp>(loc, transVShape, valueElType)
+            .getResult();
+    Value transposeV =
+        rewriter.create<linalg::TransposeOp>(loc, value, initTransV, perm)
+            ->getResult(0);
+
+    // Generate the transposed V map.
+    SmallVector<AffineExpr> newExprs =
+        applyPermutation(vMap.getResults(), perm);
+    AffineMap transposedVMap =
+        AffineMap::get(vMap.getNumDims(), vMap.getNumSymbols(), newExprs,
+                       rewriter.getContext());
+
+    // Modify the attention op to use the transposed V operand and map.
+    int64_t valueIndex = valueOpOperand->getOperandNumber();
+    rewriter.modifyOpInPlace(attentionOp, [&]() {
+      SmallVector<AffineMap> newIndexingMaps =
+          attentionOp.getIndexingMapsArray();
+      newIndexingMaps[valueIndex] = transposedVMap;
+      attentionOp.setIndexingMapsAttr(
+          rewriter.getAffineMapArrayAttr(newIndexingMaps));
+      attentionOp.setOperand(valueIndex, transposeV);
+    });
+    return success();
+  }
+
+private:
+  linalg::ControlFusionFn controlFn;
+};
+
 } // namespace
 
 void populateFuseLinalgExtOpsWithTransposes(
@@ -110,4 +208,11 @@ void populateFuseLinalgExtOpsWithTransposes(
                                                    controlFusionFn);
 }
 
+void populateBubbleTransposeFromLinalgExtOps(
+    RewritePatternSet &patterns,
+    const linalg::ControlFusionFn &controlFusionFn) {
+  patterns.add<BubbleTransposeVFromAttentionOp>(patterns.getContext(),
+                                                controlFusionFn);
+}
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp
index 41db09f07a16..3c1a783ecba3 100644
--- a/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp
@@ -104,7 +104,6 @@ struct GatherFusionPattern final : public OpRewritePattern<tensor::ExtractOp> {
 
 void ElementwiseOpFusionPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  RewritePatternSet fusionPatterns(context);
   // Only fuse operations where all uses of the producer are generic
   // operations. If an operation is used in a named op, it will be computed
   // anyway, so the consumers can just use that value.
@@ -135,24 +134,35 @@ void ElementwiseOpFusionPass::runOnOperation() {
     return areFusableAsElementwiseOps(context, fusedOperand,
                                       fuseMultiReduction);
   };
-  linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
+
+  RewritePatternSet linalgFusionPatterns(context);
+  linalg::populateElementwiseOpsFusionPatterns(linalgFusionPatterns,
                                                fuseElementwiseOpsControlFn);
+
+  GreedyRewriteConfig rewriteConfig;
+  rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
+  if (failed(applyPatternsAndFoldGreedily(
+          getOperation(), std::move(linalgFusionPatterns), rewriteConfig))) {
+    getOperation()->emitOpError(
+        "Failed to fuse elementwise ops with upstream patterns.");
+    return signalPassFailure();
+  }
+
+  // Try to fuse with LinalgExt patterns.
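+  // Running the upstream elementwise fusion patterns to a fixed point first
+  // (above) gives a bubbled transpose a chance to fuse with other ops before
+  // the LinalgExt patterns below fold any remaining transposes back into ops
+  // like attention.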
   linalg::ControlFusionFn foldTransposeControlFn = [](OpOperand *fusedOperand) {
     Operation *producer = fusedOperand->get().getDefiningOp();
     Operation *consumer = fusedOperand->getOwner();
     return IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer});
   };
+  RewritePatternSet linalgExtFusionPatterns(context);
   IREE::LinalgExt::populateFuseLinalgExtOpsWithTransposes(
-      fusionPatterns, foldTransposeControlFn);
-  fusionPatterns.insert<GatherFusionPattern>(context);
-
-  GreedyRewriteConfig rewriteConfig;
-  rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
+      linalgExtFusionPatterns, foldTransposeControlFn);
+  linalgExtFusionPatterns.insert<GatherFusionPattern>(context);
   if (failed(applyPatternsAndFoldGreedily(
-          getOperation(), std::move(fusionPatterns), rewriteConfig))) {
-    getOperation()->emitOpError("Failed to perform elementwise operations");
+          getOperation(), std::move(linalgExtFusionPatterns), rewriteConfig))) {
+    getOperation()->emitOpError(
+        "Failed to fuse elementwise ops with linalgExt patterns.");
     return signalPassFailure();
   }
 }
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
index 8b556a03835d..096c882ab219 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
@@ -207,3 +207,54 @@ util.func public @fuse_generic_gather2(
 // CHECK-NEXT:    %[[RES3:[a-zA-Z0-9]+]] = arith.mulf %[[RES]], %[[RES]] : f32
 // CHECK-NEXT:    %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32
 // CHECK-NEXT:    linalg.yield %[[RES4]] : f32
+
+util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf16>, %k: tensor<2x10x4096x64xf16>, %quantized_v: tensor<2x10x4096x64xi32>, %quant_offset: tensor<10x64xi32>, %quant_scale: tensor<10x64xf32>, %scale: f16) -> tensor<2x10x4096x64xf16> {
+  // Dequantize the int-quantized V.
+  %init_dequant = tensor.empty() : tensor<2x10x4096x64xf16>
+  %v = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%quantized_v, %quant_offset, %quant_scale : tensor<2x10x4096x64xi32>, tensor<10x64xi32>, tensor<10x64xf32>) outs(%init_dequant : tensor<2x10x4096x64xf16>) {
+  ^bb0(%in: i32, %in_0: i32, %in_1: f32, %out: f16):
+    %19 = arith.addi %in, %in_0 : i32
+    %20 = arith.sitofp %19 : i32 to f32
+    %21 = arith.mulf %20, %in_1 : f32
+    %22 = arith.truncf %21 : f32 to f16
+    linalg.yield %22 : f16
+  } -> tensor<2x10x4096x64xf16>
+
+  // Transpose V.
+  %init_transpose = tensor.empty() : tensor<2x10x64x4096xf16>
+  %transpose_v = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%v : tensor<2x10x4096x64xf16>) outs(%init_transpose : tensor<2x10x64x4096xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    linalg.yield %in : f16
+  } -> tensor<2x10x64x4096xf16>
+
+  // Attention with transposed V.
+  %init_attention = tensor.empty() : tensor<2x10x4096x64xf16>
+  %attention = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%q, %k, %transpose_v, %scale : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>, f16) outs(%init_attention : tensor<2x10x4096x64xf16>) {
+  ^bb0(%score: f16):
+    iree_linalg_ext.yield %score : f16
+  } -> tensor<2x10x4096x64xf16>
+  util.return %attention : tensor<2x10x4096x64xf16>
+}
+
+// CHECK-LABEL: util.func public @fuse_transpose_attention_to_producer
+// CHECK-SAME:    %[[ARG0:[A-Za-z0-9]+]]: tensor<2x10x4096x64xf16>
+// CHECK-SAME:    %[[ARG1:[A-Za-z0-9]+]]: tensor<2x10x4096x64xf16>
+// CHECK-SAME:    %[[ARG2:[A-Za-z0-9]+]]: tensor<2x10x4096x64xi32>
+// CHECK-SAME:    %[[ARG3:[A-Za-z0-9]+]]: tensor<10x64xi32>
+// CHECK-SAME:    %[[ARG4:[A-Za-z0-9]+]]: tensor<10x64xf32>
+// CHECK-SAME:    %[[ARG5:[A-Za-z0-9]+]]: f16
+// CHECK:       %[[DEQUANT_V:.+]] = linalg.generic
+// CHECK-SAME:    indexing_maps =
+// CHECK-SAME:    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-SAME:    affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-SAME:    affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-SAME:    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>]
+// CHECK-SAME:    ins(%[[ARG2]], %[[ARG3]], %[[ARG4]]
+// CHECK:       %[[RESULT:.+]] = iree_linalg_ext.attention
+// CHECK-SAME:    indexing_maps =
+// CHECK-SAME:    affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK-SAME:    affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
+// CHECK-SAME:    affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+// CHECK-SAME:    affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-SAME:    affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+// CHECK-SAME:    ins(%[[ARG0]], %[[ARG1]], %[[DEQUANT_V]], %[[ARG5]]
diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
index 846233841732..265ddbbc5890 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
@@ -13,6 +13,7 @@
 
 #include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.h"
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "iree/compiler/GlobalOptimization/Passes.h"
 #include "llvm/Support/Debug.h"
@@ -1087,6 +1088,15 @@ void PropagateLinalgTransposePass::runOnOperation() {
     linalg::populateFoldReshapeOpsByExpansionPatterns(bubblingPatterns,
                                                       reshapePropagationFn);
     linalg::FillOp::getCanonicalizationPatterns(bubblingPatterns, context);
+    linalg::ControlFusionFn bubbleTransposeControlFn =
+        [](OpOperand *fusedOperand) {
+          Operation *producer = fusedOperand->get().getDefiningOp();
+          Operation *consumer = fusedOperand->getOwner();
+
+          return IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer});
+        };
+    IREE::LinalgExt::populateBubbleTransposeFromLinalgExtOps(
+        bubblingPatterns, bubbleTransposeControlFn);
     bubblingPatterns.insert(
         context, enableAggressivePropagation);
     bubblingPatterns.insert(context);
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
index 6b9571666808..16f37473eb47 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
@@ -665,3 +665,47 @@ util.func public @bubble_transpose_to_broadcast_elementwise(%arg0: tensor<2x3x4x
 // BUBBLE-SAME:   ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x4xf32>
 // BUBBLE:        arith.addf
 // BUBBLE:        util.return %[[ELEM]] : tensor<3x4x2xf32>
+
+// -----
+
+util.func public @bubble_transpose_v_from_attention(%q: tensor<2x10x4096x64xf16>, %k: tensor<2x10x4096x64xf16>, %quantized_v: tensor<2x10x4096x64xi32>, %quant_offset: tensor<10x64xi32>, %quant_scale: tensor<10x64xf32>, %scale: f16) -> tensor<2x10x4096x64xf16> {
+  // Dequantize the int-quantized V.
+  %init_dequant = tensor.empty() : tensor<2x10x4096x64xf16>
+  %v = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%quantized_v, %quant_offset, %quant_scale : tensor<2x10x4096x64xi32>, tensor<10x64xi32>, tensor<10x64xf32>) outs(%init_dequant : tensor<2x10x4096x64xf16>) {
+  ^bb0(%in: i32, %in_0: i32, %in_1: f32, %out: f16):
+    %19 = arith.addi %in, %in_0 : i32
+    %20 = arith.sitofp %19 : i32 to f32
+    %21 = arith.mulf %20, %in_1 : f32
+    %22 = arith.truncf %21 : f32 to f16
+    linalg.yield %22 : f16
+  } -> tensor<2x10x4096x64xf16>
+
+  // Attention with standard (non-transposed) V.
+  %init_attention = tensor.empty() : tensor<2x10x4096x64xf16>
+  %attention = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%q, %k, %v, %scale : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, f16) outs(%init_attention : tensor<2x10x4096x64xf16>) {
+  ^bb0(%score: f16):
+    iree_linalg_ext.yield %score : f16
+  } -> tensor<2x10x4096x64xf16>
+  util.return %attention : tensor<2x10x4096x64xf16>
+}
+
+// CHECK-DAG:   #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK-DAG:   #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
+// CHECK-DAG:   #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+// CHECK-DAG:   #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-DAG:   #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+
+// CHECK-LABEL: util.func public @bubble_transpose_v_from_attention(
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<2x10x4096x64xf16>, %[[ARG1:.*]]: tensor<2x10x4096x64xf16>, %[[ARG2:.*]]: tensor<2x10x4096x64xi32>,
+// CHECK-SAME:    %[[ARG3:.*]]: tensor<10x64xi32>, %[[ARG4:.*]]: tensor<10x64xf32>, %[[ARG5:.*]]: f16) -> tensor<2x10x4096x64xf16> {
+// CHECK:       %[[EMPTY:.*]] = tensor.empty() : tensor<2x10x4096x64xf16>
+// CHECK:       %[[DEQUANT_V:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[ARG2]], %[[ARG3]], %[[ARG4]] : tensor<2x10x4096x64xi32>, tensor<10x64xi32>, tensor<10x64xf32>)
+// CHECK-SAME:    outs(%{{.*}} : tensor<2x10x4096x64xf16>)
+// CHECK:       %[[TRANS_V:.*]] = linalg.transpose ins(%[[DEQUANT_V]] : tensor<2x10x4096x64xf16>) outs({{.*}} : tensor<2x10x64x4096xf16>) permutation = [0, 1, 3, 2]
+// CHECK:       %[[ATTN:.*]] = iree_linalg_ext.attention
+// CHECK-SAME:    {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
+// CHECK-SAME:    ins(%[[ARG0]], %[[ARG1]], %[[TRANS_V]], %[[ARG5]] : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>, f16)
+// CHECK-SAME:    outs(%[[EMPTY]] : tensor<2x10x4096x64xf16>)
+// CHECK:       util.return %[[ATTN]] : tensor<2x10x4096x64xf16>