From b76f89ca8ba1a9e4935575826f1de4a3777e9e94 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins
Date: Wed, 7 Aug 2024 11:50:09 -0400
Subject: [PATCH] [Codegen][GPU] Add producer fusion pattern to loop fusion
 and hoisting pass (#18118)

This adds an additional step to the FuseAndHoistParallelLoops pass to
allow further fusion of producers after doing consumer fusion. For
example,

```
scf.forall
    |    transpose
    |       |
    v       v
   linalg.add
```

Where consumer fusion will fuse the add, but we need another step of
producer fusion to fuse in the transpose.

Also removes hal.binding.subspan ops from the tests.
---
 .../Transforms/FuseAndHoistParallelLoops.cpp  |  42 +++++++
 .../test/fuse_and_hoist_forall.mlir           | 118 +++++++++---------
 2 files changed, 98 insertions(+), 62 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
index 895e9ee0e461..bd52dcdd02ba 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -114,6 +115,34 @@ struct FuseTilableDestinationProducers final : OpRewritePattern<scf::ForallOp> {
   }
 };
 
+struct FuseTilableSliceProducers final
+    : OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    if (sliceOp->use_empty()) {
+      return failure();
+    }
+    auto tilableProducer = sliceOp.getSource().getDefiningOp<TilingInterface>();
+    if (!tilableProducer) {
+      return failure();
+    }
+
+    auto parentForall = sliceOp->getParentOfType<scf::ForallOp>();
+    if (!parentForall) {
+      return failure();
+    }
+
+    SmallVector<LoopLikeOpInterface> loops = {parentForall};
+    std::optional<scf::SCFFuseProducerOfSliceResult> fusionResult =
+        mlir::scf::tileAndFuseProducerOfSlice(rewriter, sliceOp, loops);
+    if (!fusionResult) {
+      return failure();
+    }
+    return success();
+  }
+};
+
 struct FuseTilableForallConsumers final
     : OpInterfaceRewritePattern<TilingInterface> {
   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
@@ -192,6 +221,19 @@ void FuseAndHoistParallelLoopsPass::runOnOperation() {
       return signalPassFailure();
     }
   }
+
+  // Finally try to do any new producer fusions.
+  {
+    RewritePatternSet patterns(context);
+    patterns.add<FuseTilableDestinationProducers>(context);
+    patterns.add<FuseTilableSliceProducers>(context);
+    tensor::populateFoldTensorEmptyPatterns(patterns);
+    scf::ForallOp::getCanonicalizationPatterns(patterns, context);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
 }
 
 } // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
index e889d821271b..a50c0e32eec3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
@@ -1,27 +1,14 @@
 // RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-fuse-and-hoist-parallel-loops))' --split-input-file | FileCheck %s
 
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
 #map = affine_map<(d0) -> (d0 * 2)>
 #map1 = affine_map<(d0) -> (d0 * 4)>
 #map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
 #map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
 #map4 = affine_map<(d0) -> (d0 * 16)>
-func.func @forall_fuse_then_hoist() {
+func.func @forall_fuse_then_hoist(%3: tensor<128x128xf16>, %4: tensor<128x128xf16>, %5: tensor<128x128xf32>) -> tensor<128x128xf32> {
   %c4 = arith.constant 4 : index
   %c128 = arith.constant 128 : index
   %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
-  %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
-  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
-  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
-  %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> -> tensor<128x128xf32>
   %6 = tensor.empty() : tensor<128x4xf16>
   %7 = tensor.empty() : tensor<4x128xf16>
   %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<128x128xf32>) {
@@ -62,8 +49,7 @@ func.func @forall_fuse_then_hoist()
     } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
     scf.yield %11 : tensor<128x128xf32>
   }
-  flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
-  return
+  return %8 : tensor<128x128xf32>
 }
 
 // CHECK-LABEL: func @forall_fuse_then_hoist
 // CHECK: %[[OUTER_PARALLEL:.+]] = scf.forall
 // CHECK: %[[LOOP:.+]] = scf.for
 // CHECK: scf.yield {{.*}} : tensor<16x16xf32>
 // CHECK: scf.forall.in_parallel
 // CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]]
-// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]]
+// CHECK: return %[[OUTER_PARALLEL]]
 
 // -----
 
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
 #map = affine_map<(d0) -> (d0 * 2)>
 #map1 = affine_map<(d0) -> (d0 * 4)>
 #map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
 #map3 = affine_map<(d0) -> (d0 * 16)>
-func.func @forall_fuse_then_hoist_mixed_mappings() {
+func.func @forall_fuse_then_hoist_mixed_mappings(%3: tensor<128x128xf16>, %5: tensor<128x128xf32>) -> tensor<128x128xf32> {
   %c4 = arith.constant 4 : index
   %c128 = arith.constant 128 : index
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : tensor<4x128xf16>
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
-  %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
-  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
-  %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> -> tensor<128x128xf32>
   %6 = tensor.empty() : tensor<128x4xf16>
   %7 = tensor.empty() : tensor<4x128xf16>
   %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<128x128xf32>) {
@@ -124,8 +99,7 @@ func.func @forall_fuse_then_hoist_mixed_mappings()
     } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
     scf.yield %11 : tensor<128x128xf32>
   }
-  flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
-  return
+  return %8 : tensor<128x128xf32>
 }
 
 // CHECK-LABEL: func @forall_fuse_then_hoist_mixed_mappings
 // CHECK: %[[OUTER_PARALLEL:.+]] = scf.forall
 // CHECK: %[[LOOP:.+]] = scf.for
 // CHECK: scf.yield {{.*}} : tensor<16x16xf32>
 // CHECK: scf.forall.in_parallel
 // CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]]
 // CHECK-NOT: scf.forall
-// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]]
+// CHECK: return %[[OUTER_PARALLEL]]
 
 // -----
 
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
 #map = affine_map<(d0) -> (d0 * 2)>
 #map1 = affine_map<(d0) -> (d0 * 4)>
 #map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
 #map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
 #map4 = affine_map<(d0) -> (d0 * 16)>
-func.func @forall_fuse_then_hoist_with_fill() {
+func.func @forall_fuse_then_hoist_with_fill(%3: tensor<128x128xf16>, %4: tensor<128x128xf16>) -> tensor<128x128xf32> {
   %c4 = arith.constant 4 : index
   %c128 = arith.constant 128 : index
   %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
-  %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
-  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
-  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
   %empty = tensor.empty() : tensor<128x128xf32>
   %cst = arith.constant 0.0 : f32
   %5 = linalg.fill ins(%cst : f32) outs(%empty : tensor<128x128xf32>) -> tensor<128x128xf32>
   %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<128x128xf32>) {
@@ -203,8 +165,7 @@ func.func @forall_fuse_then_hoist_with_fill()
     } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
     scf.yield %11 : tensor<128x128xf32>
   }
-  flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
-  return
+  return %8 : tensor<128x128xf32>
 }
 
 // CHECK-LABEL: func @forall_fuse_then_hoist_with_fill
 // CHECK: %[[OUTER_PARALLEL:.+]] = scf.forall
 // CHECK: linalg.fill
 // CHECK: %[[LOOP:.+]] = scf.for
 // CHECK: scf.yield {{.*}} : tensor<16x16xf32>
 // CHECK: scf.forall.in_parallel
 // CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]]
-// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]]
+// CHECK: return %[[OUTER_PARALLEL]]
 
 // -----
 
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
-func.func @multi_hoist_and_fuse_trailing_stuff() {
+func.func @multi_hoist_and_fuse_trailing_stuff(%2: tensor<128x128xf16>) -> tensor<128x128xf16> {
   %c4 = arith.constant 4 : index
   %c128 = arith.constant 128 : index
   %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<writeonly:tensor<128x128xf16>>
-  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
   %empty = tensor.empty() : tensor<128x128xf16>
   %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %empty) -> (tensor<128x128xf16>) {
     %9 = scf.forall (%arg2, %arg3) in (2, 2) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf16>) {
@@ -252,8 +203,7 @@ func.func @multi_hoist_and_fuse_trailing_stuff()
   }
   %transpose = linalg.transpose ins(%8: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) permutation = [1, 0]
   %ceil = linalg.ceil ins(%transpose: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) -> tensor<128x128xf16>
-  flow.dispatch.tensor.store %ceil, %1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<128x128xf16>>
-  return
+  return %ceil : tensor<128x128xf16>
 }
 
 // CHECK-LABEL: func @multi_hoist_and_fuse_trailing_stuff
 // CHECK: scf.forall
 // CHECK: scf.forall
 // CHECK: %[[T:.+]] = linalg.transpose ins(%{{.*}} : tensor<2x4xf16>)
@@ -265,4 +215,48 @@ func.func @multi_hoist_and_fuse_trailing_stuff()
 // CHECK: linalg.ceil ins(%[[T]] : tensor<4x2xf16>) {{.*}} -> tensor<4x2xf16>
 // CHECK: scf.forall.in_parallel
 // CHECK: scf.forall.in_parallel
-// CHECK: flow.dispatch.tensor.store
+// CHECK: return
+
+// -----
+
+func.func @multi_hoist_and_fuse_trailing_with_producer_fusion(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -> tensor<128x128xf16> {
+  %c4 = arith.constant 4 : index
+  %c128 = arith.constant 128 : index
+  %c0 = arith.constant 0 : index
+  %empty = tensor.empty() : tensor<128x128xf16>
+  %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %empty) -> (tensor<128x128xf16>) {
+    %9 = scf.forall (%arg2, %arg3) in (2, 2) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf16>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg2, %arg3] [64, 64] [1, 1] : tensor<128x128xf16> to tensor<64x64xf16>
+      %10 = scf.forall (%arg5, %arg6) in (32, 16) shared_outs(%arg7 = %extracted_slice) -> (tensor<64x64xf16>) {
+        %extracted_slice_1 = tensor.extract_slice %2[%arg5, %arg6] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
+        %extracted_slice_2 = tensor.extract_slice %arg7[%arg5, %arg6] [2, 4] [1, 1] : tensor<64x64xf16> to tensor<2x4xf16>
+        %16 = linalg.copy ins(%extracted_slice_1 : tensor<2x4xf16>) outs(%extracted_slice_2 : tensor<2x4xf16>) -> tensor<2x4xf16>
+        scf.forall.in_parallel {
+          tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<64x64xf16>
+        }
+      } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %10 into %arg4[%arg2, %arg3] [64, 64] [1, 1] : tensor<64x64xf16> into tensor<128x128xf16>
+      }
+    } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
+    scf.yield %9 : tensor<128x128xf16>
+  }
+  %transpose_input = linalg.transpose ins(%3: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) permutation = [1, 0]
+  %add = linalg.add
+    ins(%8, %transpose_input : tensor<128x128xf16>, tensor<128x128xf16>)
+    outs(%empty: tensor<128x128xf16>) -> tensor<128x128xf16>
+  return %add : tensor<128x128xf16>
+}
+
+// CHECK-LABEL: func @multi_hoist_and_fuse_trailing_with_producer_fusion
+// CHECK-SAME: %[[I0:[A-Za-z0-9]+]]: tensor<128x128xf16>
+// CHECK-SAME: %[[I1:[A-Za-z0-9]+]]: tensor<128x128xf16>
+// CHECK: scf.forall
+// CHECK: scf.forall
+// CHECK: %[[LOOP:.+]] = scf.for {{.*}} -> (tensor<2x4xf16>)
+// CHECK: linalg.copy
+// CHECK: %[[T:.+]] = linalg.transpose ins(%{{.*}} : tensor<4x2xf16>)
+// CHECK: linalg.add ins(%[[LOOP]], %[[T]] : tensor<2x4xf16>, tensor<2x4xf16>) {{.*}} -> tensor<2x4xf16>
+// CHECK: scf.forall.in_parallel
+// CHECK: scf.forall.in_parallel
+// CHECK: return
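
Illustrative example (hypothetical IR, not lines from this patch or its tests): the new FuseTilableSliceProducers pattern matches a tensor.extract_slice whose source is defined by a TilingInterface op outside the enclosing scf.forall and calls mlir::scf::tileAndFuseProducerOfSlice to recreate just the needed tile of that producer inside the loop; the tensor.empty folding and scf.forall canonicalization patterns registered alongside it clean up the leftover init and slices. The sketch below shows roughly the intended effect; the function names @before/@after, the bounds, and the tile sizes are made up for illustration.

```
// (Hypothetical IR - names and sizes are illustrative.)
// Before: the transpose runs at full size outside the thread-mapped loop,
// and each iteration only reads a 2x4 slice of the result.
func.func @before(%in: tensor<128x128xf16>, %dest: tensor<128x128xf16>) -> tensor<128x128xf16> {
  %empty = tensor.empty() : tensor<128x128xf16>
  %t = linalg.transpose ins(%in: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) permutation = [1, 0]
  %r = scf.forall (%i, %j) in (32, 16) shared_outs(%out = %dest) -> (tensor<128x128xf16>) {
    // This extract_slice of a TilingInterface-defined value is what the
    // pattern matches.
    %src = tensor.extract_slice %t[%i, %j] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
    %dst = tensor.extract_slice %out[%i, %j] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
    %cp = linalg.copy ins(%src : tensor<2x4xf16>) outs(%dst : tensor<2x4xf16>) -> tensor<2x4xf16>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %cp into %out[%i, %j] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x128xf16>
    }
  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
  return %r : tensor<128x128xf16>
}

// After producer fusion: only a 4x2 tile of %in is read and transposed per
// iteration (note the swapped offsets and tile shape), and the full-size
// intermediate is gone.
func.func @after(%in: tensor<128x128xf16>, %dest: tensor<128x128xf16>) -> tensor<128x128xf16> {
  %r = scf.forall (%i, %j) in (32, 16) shared_outs(%out = %dest) -> (tensor<128x128xf16>) {
    %in_tile = tensor.extract_slice %in[%j, %i] [4, 2] [1, 1] : tensor<128x128xf16> to tensor<4x2xf16>
    %init = tensor.empty() : tensor<2x4xf16>
    %t_tile = linalg.transpose ins(%in_tile: tensor<4x2xf16>) outs(%init: tensor<2x4xf16>) permutation = [1, 0]
    %dst = tensor.extract_slice %out[%i, %j] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
    %cp = linalg.copy ins(%t_tile : tensor<2x4xf16>) outs(%dst : tensor<2x4xf16>) -> tensor<2x4xf16>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %cp into %out[%i, %j] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x128xf16>
    }
  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
  return %r : tensor<128x128xf16>
}
```

The pattern bails out when the slice has no uses, when the slice source is not defined by a TilingInterface op, or when no scf.forall encloses the slice, so this final greedy phase only fires on slices that the earlier consumer-fusion and hoisting steps left behind.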