From 16e2931a2ccbe8d714f938942dc272428f8b31ca Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Sun, 27 Aug 2023 11:00:30 -0400 Subject: [PATCH] [Flow] Raise batch_matmul(a, transpose(b)) to batch_matmul_transpose_b (#14847) Adds a similar raising pattern as that for matmul(a, transpose(b)). --- .../Flow/Transforms/RaiseSpecialOps.cpp | 56 ++++++++++++++++--- .../Transforms/test/raise_special_ops.mlir | 24 ++++++++ 2 files changed, 71 insertions(+), 9 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp index 8d4c4d9cf681..03fa3f8e6b67 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp @@ -28,23 +28,30 @@ namespace Flow { namespace { -// Method to match a transpose operation. -static bool match2DTranspose(linalg::LinalgOp genericOp) { +// Method to match a transpose operation on the two most minor dimensions of the +// specified rank. +static bool matchInner2DTranspose(linalg::LinalgOp genericOp, unsigned rank) { + // Only makes sense for minimum rank 2. + if (rank < 2) { + return false; + } if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) { return false; } - // Check only for 2D ops. - if (genericOp.getNumLoops() != 2 || + // Check only for ops of the specified rank. + if (genericOp.getNumLoops() != rank || genericOp.getNumLoops() != genericOp.getNumParallelLoops()) { return false; } // Check for transpose map. - AffineExpr d0, d1; + SmallVector exprList(rank); MLIRContext *context = genericOp.getContext(); - bindDims(context, d0, d1); + bindDimsList(context, MutableArrayRef{exprList}); + SmallVector transposeExprList(exprList); + std::swap(transposeExprList[rank - 1], transposeExprList[rank - 2]); SmallVector expectedMaps = { - AffineMap::get(2, 0, {d0, d1}, context), - AffineMap::get(2, 0, {d1, d0}, context)}; + AffineMap::get(rank, 0, exprList, context), + AffineMap::get(rank, 0, transposeExprList, context)}; if (genericOp.getIndexingMapsArray() != expectedMaps) { return false; } @@ -70,7 +77,21 @@ std::optional matchATransposeBMatmul(linalg::LinalgOp matmulOp) { } auto rhs = matmulOp.getDpsInputOperand(1); auto genericOp = rhs->get().getDefiningOp(); - if (genericOp && match2DTranspose(genericOp)) { + if (genericOp && matchInner2DTranspose(genericOp, 2)) { + return genericOp.getDpsInputOperand(0)->get(); + } + return std::nullopt; +} + +// Method to match a linalg.batch_matmul(a, linalg.transpose(b)). Returns `b` on +// success. +std::optional matchATransposeBBatchMatmul(linalg::LinalgOp bmmOp) { + if (!isa(bmmOp.getOperation())) { + return std::nullopt; + } + auto rhs = bmmOp.getDpsInputOperand(1); + auto genericOp = rhs->get().getDefiningOp(); + if (genericOp && matchInner2DTranspose(genericOp, 3)) { return genericOp.getDpsInputOperand(0)->get(); } return std::nullopt; @@ -361,6 +382,8 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase { SmallVector> softmaxRoots; SmallVector> transposeMatmulRoots; + SmallVector> + transposeBatchMatmulRoots; SmallVector> genericFills; getOperation()->walk([&](linalg::LinalgOp op) { { @@ -376,6 +399,10 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase { transposeMatmulRoots.push_back(std::make_pair( cast(op.getOperation()), newRhs.value())); } + if (std::optional newRhs = matchATransposeBBatchMatmul(op)) { + transposeBatchMatmulRoots.push_back(std::make_pair( + cast(op.getOperation()), newRhs.value())); + } if (std::optional fillInput = matchGenericFill(op)) { genericFills.push_back( std::make_pair(cast(op), fillInput.value())); @@ -402,6 +429,17 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase { rewriter.replaceOpWithNewOp( matmulOp, ValueRange{lhs, newRhs}, ValueRange{init}, attrs); } + for (std::pair aTransposeBBatchMatmul : + transposeBatchMatmulRoots) { + auto bmmOp = aTransposeBBatchMatmul.first; + Value lhs = bmmOp.getDpsInputOperand(0)->get(); + auto newRhs = aTransposeBBatchMatmul.second; + Value init = bmmOp.getDpsInitOperand(0)->get(); + rewriter.setInsertionPoint(bmmOp); + SmallVector attrs = getPrunedAttributeList(bmmOp); + rewriter.replaceOpWithNewOp( + bmmOp, ValueRange{lhs, newRhs}, ValueRange{init}, attrs); + } for (std::pair genericFill : genericFills) { auto genericOp = genericFill.first; Value fillInput = genericFill.second; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir index 54835d52f745..5e238aaebe88 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir @@ -187,6 +187,30 @@ func.func @aTransposeBMatmul(%arg0 : tensor<10x20xf32>, // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : // CHECK: return %[[RESULT]] +func.func @aTransposeBBatchMatmul(%arg0 : tensor<5x10x20xf32>, + %arg1 : tensor<5x40x20xf32>) -> tensor<5x10x40xf32> { + %0 = tensor.empty() : tensor<5x20x40xf32> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%arg1 : tensor<5x40x20xf32>) outs(%0 : tensor<5x20x40xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + linalg.yield %b0 : f32 + } -> tensor<5x20x40xf32> + %2 = tensor.empty() : tensor<5x10x40xf32> + %3 = arith.constant 0.0 : f32 + %4 = linalg.fill ins(%3 : f32) outs(%2 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32> + %5 = linalg.batch_matmul ins(%arg0, %1 : tensor<5x10x20xf32>, tensor<5x20x40xf32>) + outs(%4 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32> + return %5 : tensor<5x10x40xf32> +} +// CHECK-LABEL: func @aTransposeBBatchMatmul +// CHECK-SAME: %[[ARG0:.+]]: tensor<5x10x20xf32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<5x40x20xf32> +// CHECK: %[[RESULT:.+]] = linalg.batch_matmul_transpose_b +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : +// CHECK: return %[[RESULT]] + func.func @generic_fill(%arg0: tensor) -> tensor<1x1x?x?xf32> { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index