Skip to content

Commit

Permalink
[DispatchCreation] Enable bubble up extract slice for `linalg.generic…
Browse files Browse the repository at this point in the history
…` op with a single use. (iree-org#19174)

For a `linalg.generic` -> `tensor.extract_slice` pattern where the
producer has a single use, this is always good to do.

TODO: This pattern could be generalized for any `LinalgOp`, but not done
now because of the rank-reduction slices need to fold the unit
dimensions.

Fixes iree-org#19173

Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar authored Nov 18, 2024
1 parent c581951 commit df83f8e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ struct BubbleUpExtract : OpRewritePattern<tensor::ExtractSliceOp> {
"single result");
}

if (!IREE::LinalgExt::isBitExtendOp(genericOp)) {
if (!IREE::LinalgExt::isBitExtendOp(genericOp) && !genericOp->hasOneUse()) {
return rewriter.notifyMatchFailure(
sliceOp, "expected source to be dequantize-like");
sliceOp,
"expected source to be dequantize-like op or have a single use");
}

if (!sliceOp.hasUnitStride()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --split-input-file --iree-dispatch-creation-bubble-up-extract-slices --iree-flow-canonicalize %s | FileCheck %s
// RUN: iree-opt --split-input-file --iree-dispatch-creation-bubble-up-extract-slices --iree-flow-canonicalize --mlir-print-local-scope %s | FileCheck %s

util.func public @bubble_up_extract_rank_reduce(%arg0 : tensor<1024x7x7x2xi8>) -> tensor<1024x7x7xf32>{
%0 = tensor.empty() : tensor<1024x7x7x2xf32>
Expand Down Expand Up @@ -115,3 +115,27 @@ util.func public @bubble_up_extract_fill_multi_use() -> tensor<2x320x130x130xf8E
// CHECK-NOT: %[[SLICE:.+]] = tensor.extract_slice
// CHECK: %[[EMPTY2:.+]] = tensor.empty
// CHECK: %[[FILL3:.+]] = linalg.fill

// -----

func.func @bubble_up_extract_slice_single_use(%arg0: tensor<131072xi64>, %arg1: tensor<1x1x131072xi64>, %arg2: index) -> tensor<?x?xi1> {
%0 = tensor.empty() : tensor<1x1x131072x131072xi1>
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<131072xi64>, tensor<1x1x131072xi64>) outs(%0 : tensor<1x1x131072x131072xi1>) {
^bb0(%in: i64, %in_0: i64, %out: i1):
%2 = arith.cmpi sge, %in, %in_0 : i64
linalg.yield %2 : i1
} -> tensor<1x1x131072x131072xi1>
%extracted_slice = tensor.extract_slice %1[0, 0, 0, 0] [1, 1, %arg2, %arg2] [1, 1, 1, 1] : tensor<1x1x131072x131072xi1> to tensor<?x?xi1>
return %extracted_slice : tensor<?x?xi1>
}
// CHECK-LABEL: func @bubble_up_extract_slice_single_use
// CHECK-SAME: %[[ARG0:.+]]: tensor<131072xi64>
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x1x131072xi64>
// CHECK-SAME: %[[ARG2:.+]]: index
// CHECK-DAG: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-DAG: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[ARG2]], %[[ARG2]])
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: return %[[GENERIC]]

0 comments on commit df83f8e

Please sign in to comment.