Skip to content

Commit

Permalink
[Codegen][GPU] Add producer fusion pattern to loop fusion and hoistin…
Browse files Browse the repository at this point in the history
…g pass (iree-org#18118)

This adds an additional step to the FuseAndHoistParallelLoops pass to
allow further fusion of producers after doing consumer fusion. For
example,

```
scf.forall | transpose
  |          |
  v          v
   linalg.add
```

Where consumer fusion will fuse the add, but we need another step of
producer fusion to fuse in the transpose.

Also removes hal.binding.subspan ops from the tests.
  • Loading branch information
qedawkins authored Aug 7, 2024
1 parent 7cf0e26 commit b76f89c
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand Down Expand Up @@ -114,6 +115,34 @@ struct FuseTilableDestinationProducers final : OpRewritePattern<scf::ForallOp> {
}
};

struct FuseTilableSliceProducers final
: OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
if (sliceOp->use_empty()) {
return failure();
}
auto tilableProducer = sliceOp.getSource().getDefiningOp<TilingInterface>();
if (!tilableProducer) {
return failure();
}

auto parentForall = sliceOp->getParentOfType<scf::ForallOp>();
if (!parentForall) {
return failure();
}

SmallVector<LoopLikeOpInterface> loops = {parentForall};
std::optional<scf::SCFFuseProducerOfSliceResult> fusionResult =
mlir::scf::tileAndFuseProducerOfSlice(rewriter, sliceOp, loops);
if (!fusionResult) {
return failure();
}
return success();
}
};

struct FuseTilableForallConsumers final
: OpInterfaceRewritePattern<TilingInterface> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
Expand Down Expand Up @@ -192,6 +221,19 @@ void FuseAndHoistParallelLoopsPass::runOnOperation() {
return signalPassFailure();
}
}

// Finally try to do any new producer fusions.
{
RewritePatternSet patterns(context);
patterns.add<FuseTilableDestinationProducers>(context);
patterns.add<FuseTilableSliceProducers>(context);
tensor::populateFoldTensorEmptyPatterns(patterns);
scf::ForallOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
}

} // namespace mlir::iree_compiler::IREE::GPU
Original file line number Diff line number Diff line change
@@ -1,27 +1,14 @@
// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-fuse-and-hoist-parallel-loops))' --split-input-file | FileCheck %s

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 4)>
#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
#map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
#map4 = affine_map<(d0) -> (d0 * 16)>
func.func @forall_fuse_then_hoist() {
func.func @forall_fuse_then_hoist(%3: tensor<128x128xf16>, %4: tensor<128x128xf16>, %5: tensor<128x128xf32>) -> tensor<128x128xf32> {
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> -> tensor<128x128xf32>
%6 = tensor.empty() : tensor<128x4xf16>
%7 = tensor.empty() : tensor<4x128xf16>
%8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<128x128xf32>) {
Expand Down Expand Up @@ -62,8 +49,7 @@ func.func @forall_fuse_then_hoist() {
} {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
scf.yield %11 : tensor<128x128xf32>
}
flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
return
return %8 : tensor<128x128xf32>
}

// CHECK-LABEL: func @forall_fuse_then_hoist
Expand All @@ -72,30 +58,19 @@ func.func @forall_fuse_then_hoist() {
// CHECK: scf.yield {{.*}} : tensor<16x16xf32>
// CHECK: scf.forall.in_parallel
// CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]]
// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]]
// CHECK: return %[[OUTER_PARALLEL]]

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 4)>
#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
#map3 = affine_map<(d0) -> (d0 * 16)>
func.func @forall_fuse_then_hoist_mixed_mappings() {
func.func @forall_fuse_then_hoist_mixed_mappings(%3: tensor<128x128xf16>, %5: tensor<128x128xf32>) -> tensor<128x128xf32> {
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.0> : tensor<4x128xf16>
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> -> tensor<128x128xf32>
%6 = tensor.empty() : tensor<128x4xf16>
%7 = tensor.empty() : tensor<4x128xf16>
%8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<128x128xf32>) {
Expand Down Expand Up @@ -124,8 +99,7 @@ func.func @forall_fuse_then_hoist_mixed_mappings() {
} {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
scf.yield %11 : tensor<128x128xf32>
}
flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
return
return %8 : tensor<128x128xf32>
}

// CHECK-LABEL: func @forall_fuse_then_hoist_mixed_mappings
Expand All @@ -135,31 +109,19 @@ func.func @forall_fuse_then_hoist_mixed_mappings() {
// CHECK: scf.forall.in_parallel
// CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]]
// CHECK-NOT: scf.forall
// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]]
// CHECK: return %[[OUTER_PARALLEL]]

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 4)>
#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
#map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
#map4 = affine_map<(d0) -> (d0 * 16)>
func.func @forall_fuse_then_hoist_with_fill() {
func.func @forall_fuse_then_hoist_with_fill(%3: tensor<128x128xf16>, %4: tensor<128x128xf16>) -> tensor<128x128xf32> {
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
%empty = tensor.empty() : tensor<128x128xf32>
%cst = arith.constant 0.0 : f32
%5 = linalg.fill ins(%cst : f32) outs(%empty : tensor<128x128xf32>) -> tensor<128x128xf32>
Expand Down Expand Up @@ -203,8 +165,7 @@ func.func @forall_fuse_then_hoist_with_fill() {
} {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
scf.yield %11 : tensor<128x128xf32>
}
flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
return
return %8 : tensor<128x128xf32>
}

// CHECK-LABEL: func @forall_fuse_then_hoist_with_fill
Expand All @@ -214,24 +175,14 @@ func.func @forall_fuse_then_hoist_with_fill() {
// CHECK: scf.yield {{.*}} : tensor<16x16xf32>
// CHECK: scf.forall.in_parallel
// CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]]
// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]]
// CHECK: return %[[OUTER_PARALLEL]]

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
func.func @multi_hoist_and_fuse_trailing_stuff() {
func.func @multi_hoist_and_fuse_trailing_stuff(%2: tensor<128x128xf16>) -> tensor<128x128xf16> {
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readwrite:tensor<128x128xf16>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
%empty = tensor.empty() : tensor<128x128xf16>
%8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %empty) -> (tensor<128x128xf16>) {
%9 = scf.forall (%arg2, %arg3) in (2, 2) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf16>) {
Expand All @@ -252,8 +203,7 @@ func.func @multi_hoist_and_fuse_trailing_stuff() {
}
%transpose = linalg.transpose ins(%8: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) permutation = [1, 0]
%ceil = linalg.ceil ins(%transpose: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) -> tensor<128x128xf16>
flow.dispatch.tensor.store %ceil, %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf16> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf16>>
return
return %ceil : tensor<128x128xf16>
}

// CHECK-LABEL: func @multi_hoist_and_fuse_trailing_stuff
Expand All @@ -265,4 +215,48 @@ func.func @multi_hoist_and_fuse_trailing_stuff() {
// CHECK: linalg.ceil ins(%[[T]] : tensor<4x2xf16>) {{.*}} -> tensor<4x2xf16>
// CHECK: scf.forall.in_parallel
// CHECK: scf.forall.in_parallel
// CHECK: flow.dispatch.tensor.store
// CHECK: return

// -----

func.func @multi_hoist_and_fuse_trailing_with_producer_fusion(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -> tensor<128x128xf16> {
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%empty = tensor.empty() : tensor<128x128xf16>
%8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %empty) -> (tensor<128x128xf16>) {
%9 = scf.forall (%arg2, %arg3) in (2, 2) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf16>) {
%extracted_slice = tensor.extract_slice %arg4[%arg2, %arg3] [64, 64] [1, 1] : tensor<128x128xf16> to tensor<64x64xf16>
%10 = scf.forall (%arg5, %arg6) in (32, 16) shared_outs(%arg7 = %extracted_slice) -> (tensor<64x64xf16>) {
%extracted_slice_1 = tensor.extract_slice %2[%arg5, %arg6] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
%extracted_slice_2 = tensor.extract_slice %arg7[%arg5, %arg6] [2, 4] [1, 1] : tensor<64x64xf16> to tensor<2x4xf16>
%16 = linalg.copy ins(%extracted_slice_1 : tensor<2x4xf16>) outs(%extracted_slice_2 : tensor<2x4xf16>) -> tensor<2x4xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<64x64xf16>
}
} {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
scf.forall.in_parallel {
tensor.parallel_insert_slice %10 into %arg4[%arg2, %arg3] [64, 64] [1, 1] : tensor<64x64xf16> into tensor<128x128xf16>
}
} {mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>]}
scf.yield %9 : tensor<128x128xf16>
}
%transpose_input = linalg.transpose ins(%3: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) permutation = [1, 0]
%add = linalg.add
ins(%8, %transpose_input : tensor<128x128xf16>, tensor<128x128xf16>)
outs(%empty: tensor<128x128xf16>) -> tensor<128x128xf16>
return %add : tensor<128x128xf16>
}

// CHECK-LABEL: func @multi_hoist_and_fuse_trailing_with_producer_fusion
// CHECK-SAME: %[[I0:[A-Za-z0-9]+]]: tensor<128x128xf16>
// CHECK-SAME: %[[I1:[A-Za-z0-9]+]]: tensor<128x128xf16>
// CHECK: scf.forall
// CHECK: scf.forall
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} -> (tensor<2x4xf16>)
// CHECK: linalg.copy
// CHECK: %[[T:.+]] = linalg.transpose ins(%{{.*}} : tensor<4x2xf16>)
// CHECK: linalg.add ins(%[[LOOP]], %[[T]] : tensor<2x4xf16>, tensor<2x4xf16>) {{.*}} -> tensor<2x4xf16>
// CHECK: scf.forall.in_parallel
// CHECK: scf.forall.in_parallel
// CHECK: return

0 comments on commit b76f89c

Please sign in to comment.