Skip to content

Commit

Permalink
[Dispatch] Add pattern to bubble expand through extract 1/2 (iree-org…
Browse files Browse the repository at this point in the history
…#19325)

This is the first of 2 changes needed to reland
iree-org#18857 (with an open PR,
iree-org#19113).


Adds a pattern to bubble up an expand_shape through an extract_slice, i.e.
rewriting `expand(extract)` into `extract(expand)`. This only supports the
case where the expanded dimensions are not modified by the extract_slice
and there are no dynamic dimensions.

This is important because `tensor.expand_shape` ops _cannot be cloned_
while `tensor.extract_slice` ops _can be cloned_. So, if the
`expand_shape` gets stuck on the bottom of the `extract_slice` it will
block it from being cloned and the `extract_slice` will have to be put
into its own dispatch.

---------

Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 authored Dec 3, 2024
1 parent a30a419 commit 529cd89
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,81 @@ struct BubbleUpExpandShapesPass final
void runOnOperation() override;
};

/// Bubbles a `tensor.expand_shape` op through a `tensor.extract_slice` op,
/// rewriting `expand(extract)` into `extract(expand)`. The pattern applies
/// only when the `extract_slice` leaves every expanded dimension untouched
/// and all shapes involved are fully static.
/// TODO: move this upstream with other tensor bubbling patterns.
struct BubbleExpandThroughExtract final
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto sliceOp = expandOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    if (!sliceOp) {
      return failure();
    }

    auto sliceSrcType = sliceOp.getSourceType();
    auto sliceResType = sliceOp.getType();
    auto expandResType = expandOp.getType();

    // Rank-reducing slices drop dimensions, which the reassociation-based
    // indexing below does not account for.
    if (sliceSrcType.getRank() != sliceResType.getRank()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "Rank reducing extract_slice not supported");
    }

    // Only the fully static case is handled.
    if (!sliceSrcType.hasStaticShape() || !sliceResType.hasStaticShape() ||
        !expandResType.hasStaticShape()) {
      return failure();
    }

    // Every dimension that the expand_shape splits must pass through the
    // slice unchanged (same extent before and after slicing).
    auto reassociations = expandOp.getReassociationIndices();
    for (int64_t dim = 0, rank = sliceResType.getRank(); dim < rank; ++dim) {
      if (reassociations[dim].size() == 1) {
        continue;
      }
      if (sliceSrcType.getShape()[dim] != sliceResType.getShape()[dim]) {
        return rewriter.notifyMatchFailure(
            sliceOp, "Extract modifies the expanded dimension");
      }
    }

    // Build the shape of the expand placed on the slice source, together with
    // the static slice parameters to re-apply on the expanded result.
    SmallVector<int64_t> newExpandShape;
    SmallVector<int64_t> newOffsets;
    SmallVector<int64_t> newSizes;
    SmallVector<int64_t> newStrides;
    for (auto [srcDim, resultDims] : llvm::enumerate(reassociations)) {
      if (resultDims.size() == 1) {
        // Unexpanded dimension: keep the original slice parameters.
        newExpandShape.push_back(sliceSrcType.getShape()[srcDim]);
        newOffsets.push_back(sliceOp.getStaticOffsets()[srcDim]);
        newSizes.push_back(sliceOp.getStaticSizes()[srcDim]);
        newStrides.push_back(sliceOp.getStaticStrides()[srcDim]);
        continue;
      }
      // Expanded dimensions are untouched by the slice, so take them whole.
      for (int64_t resultDim : resultDims) {
        int64_t extent = expandResType.getShape()[resultDim];
        newExpandShape.push_back(extent);
        newOffsets.push_back(0);
        newSizes.push_back(extent);
        newStrides.push_back(1);
      }
    }

    Type newExpandType =
        RankedTensorType::get(newExpandShape, expandResType.getElementType());
    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
        expandOp.getLoc(), newExpandType, sliceOp.getSource(), reassociations);

    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandOp, expandResType, newExpand, ValueRange{}, ValueRange{},
        ValueRange{}, newOffsets, newSizes, newStrides);
    return success();
  }
};

} // namespace

void BubbleUpExpandShapesPass::runOnOperation() {
Expand Down Expand Up @@ -87,6 +162,7 @@ void BubbleUpExpandShapesPass::runOnOperation() {
// Add patterns to do some additional cleanup (on top of canonicalizations
// that can be done later) of reshape ops.
tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
bubbleExpandShapePatterns.insert<BubbleExpandThroughExtract>(context);

GreedyRewriteConfig rewriteConfig;
rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ iree_lit_test_suite(
"form_dispatch_regions.mlir",
"dispatch_linalg_on_tensors.mlir",
"convert_region_to_workgroups.mlir",
"bubble_up_expand_shapes.mlir",
"bubble_up_extract_slice.mlir",
"form_dispatch_workgroups.mlir",
"dispatch_linalg_ext_fusion.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_lit_test_suite(
lit
SRCS
"attention_fuse_by_expansion.mlir"
"bubble_up_expand_shapes.mlir"
"bubble_up_extract_slice.mlir"
"clone_producers_into_dispatch_regions.mlir"
"collapse_dimensions.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-bubble-up-expand-shapes))" %s | FileCheck %s

util.func public @bubbble_expand_through_extract(%arg0 : tensor<2x4096x5120xf16>) -> (tensor<2x64x64x2560xf16>) {
%extracted_slice_237 = tensor.extract_slice %arg0[0, 0, 0] [2, 4096, 2560] [1, 1, 1] : tensor<2x4096x5120xf16> to tensor<2x4096x2560xf16>
%expanded_239 = tensor.expand_shape %extracted_slice_237 [[0], [1, 2], [3]] output_shape [2, 64, 64, 2560] : tensor<2x4096x2560xf16> into tensor<2x64x64x2560xf16>
util.return %expanded_239 : tensor<2x64x64x2560xf16>
}

// CHECK-LABEL: @bubbble_expand_through_extract
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[EXPAND]]

// -----

util.func public @unsupported_bubbble_expand_through_extract(%arg0 : tensor<2x4096x5120xf16>) -> (tensor<2x32x64x2560xf16>) {
%extracted_slice_237 = tensor.extract_slice %arg0[0, 0, 0] [2, 2048, 2560] [1, 1, 1] : tensor<2x4096x5120xf16> to tensor<2x2048x2560xf16>
%expanded_239 = tensor.expand_shape %extracted_slice_237 [[0], [1, 2], [3]] output_shape [2, 32, 64, 2560] : tensor<2x2048x2560xf16> into tensor<2x32x64x2560xf16>
util.return %expanded_239 : tensor<2x32x64x2560xf16>
}

// CHECK-LABEL: @unsupported_bubbble_expand_through_extract
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[EXTRACT]]
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ util.func public @bubble_up_extract_with_use(%arg0 : tensor<1024x7x7x2xi8>) -> (
// CHECK-SAME: ins(%[[EXTRACT0]] : tensor<1024x7x7xi8>)
// CHECK: util.return %[[GENERIC1]], %[[GENERIC0]]

// -----

util.func public @bubble_up_extract_fill_multi_use() -> tensor<2x320x130x130xf8E4M3FNUZ> {
%cst_1 = arith.constant 1.000000e+00 : f8E4M3FNUZ
%cst_2 = arith.constant 2.000000e+00 : f8E4M3FNUZ
Expand Down

0 comments on commit 529cd89

Please sign in to comment.