diff --git a/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp index e75e9f28b97f..41db09f07a16 100644 --- a/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" +#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" #include "iree/compiler/DispatchCreation/FusionUtils.h" #include "iree/compiler/DispatchCreation/Passes.h" #include "llvm/Support/Debug.h" @@ -35,6 +36,69 @@ struct ElementwiseOpFusionPass final void runOnOperation() override; }; +//===----------------------------------------------------------------------===// +// GatherFusionPattern +//===----------------------------------------------------------------------===// + +// Specific case. The linalg generic implementation of "gather" +// cannot be fused because it there is no producer-consumer +// relationship between the two generics. This is because the indexing +// is not affine (index values come from a tensor). +struct GatherFusionPattern final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + // Check if extractOp is inside a generic op + auto consumerOp = + dyn_cast_or_null(extractOp->getParentOp()); + if (!consumerOp) { + return rewriter.notifyMatchFailure( + extractOp, "expected extract op to be inside a generic op"); + } + + auto producerOp = extractOp.getTensor().getDefiningOp(); + if (!producerOp) { + return rewriter.notifyMatchFailure( + consumerOp, "expected extract operand to be a generic op"); + } + + // Check if the producerOp is fusible + if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 || + !isElementwise(producerOp) || + !IREE::LinalgExt::isBitExtendOp(producerOp)) { + return rewriter.notifyMatchFailure(producerOp, + "producer op is not fusible"); + } + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(extractOp); + + // Create a new extract op that extracts from the original tensor + // (after the original extract). Clone the producerOp's body into the + // consumerOp, inline the cloned block (erases the block) after the new + // extract, and clean up. + auto newExtractOp = rewriter.create( + extractOp.getLoc(), producerOp.getDpsInputOperand(0)->get(), + extractOp.getIndices()); + rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(), + consumerOp.getRegion().begin()); + Block &clonedBlock = consumerOp.getRegion().front(); + auto producerTermOp = clonedBlock.getTerminator(); + + rewriter.inlineBlockBefore( + &clonedBlock, extractOp->getNextNode(), + {newExtractOp.getResult(), newExtractOp.getResult()}); + + // Replace the the all references to the original extract result with the + // result from the inlined producerOp. + extractOp.getResult().replaceAllUsesWith(producerTermOp->getOperand(0)); + rewriter.eraseOp(producerTermOp); + rewriter.eraseOp(extractOp); + + return success(); + } +}; + } // namespace void ElementwiseOpFusionPass::runOnOperation() { @@ -82,6 +146,7 @@ void ElementwiseOpFusionPass::runOnOperation() { }; IREE::LinalgExt::populateFuseLinalgExtOpsWithTransposes( fusionPatterns, foldTransposeControlFn); + fusionPatterns.insert(context); GreedyRewriteConfig rewriteConfig; rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit; diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp index 3fdafd7f0246..158775571e30 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp @@ -149,76 +149,12 @@ struct FoldSuccessiveTensorInsertSliceOps final } }; -//===----------------------------------------------------------------------===// -// GatherFusionPattern -//===----------------------------------------------------------------------===// - -// Specific case. The linalg generic implementation of "gather" -// cannot be fused because it there is no producer-consumer -// relationship between the two generics. This is because the indexing -// is not affine (index values come from a tensor). -struct GatherFusionPattern final : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, - PatternRewriter &rewriter) const override { - // Check if extractOp is inside a generic op - auto consumerOp = - dyn_cast_or_null(extractOp->getParentOp()); - if (!consumerOp) { - return rewriter.notifyMatchFailure( - extractOp, "expected extract op to be inside a generic op"); - } - - auto producerOp = extractOp.getTensor().getDefiningOp(); - if (!producerOp) { - return rewriter.notifyMatchFailure( - consumerOp, "expected extract operand to be a generic op"); - } - - // Check if the producerOp is fusible - if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 || - !isElementwise(producerOp) || - !IREE::LinalgExt::isBitExtendOp(producerOp)) { - return rewriter.notifyMatchFailure(producerOp, - "producer op is not fusible"); - } - - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(extractOp); - - // Create a new extract op that extracts from the original tensor - // (after the original extract). Clone the producerOp's body into the - // consumerOp, inline the cloned block (erases the block) after the new - // extract, and clean up. - auto newExtractOp = rewriter.create( - extractOp.getLoc(), producerOp.getDpsInputOperand(0)->get(), - extractOp.getIndices()); - rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(), - consumerOp.getRegion().begin()); - Block &clonedBlock = consumerOp.getRegion().front(); - auto producerTermOp = clonedBlock.getTerminator(); - - rewriter.inlineBlockBefore( - &clonedBlock, extractOp->getNextNode(), - {newExtractOp.getResult(), newExtractOp.getResult()}); - - // Replace the the all references to the original extract result with the - // result from the inlined producerOp. - extractOp.getResult().replaceAllUsesWith(producerTermOp->getOperand(0)); - rewriter.eraseOp(producerTermOp); - rewriter.eraseOp(extractOp); - - return success(); - } -}; - struct FusionPreprocessingPass final : public impl::FusionPreprocessingPassBase { void runOnOperation() override { RewritePatternSet patterns(&getContext()); patterns.add( - &getContext()); + FoldSuccessiveTensorInsertSliceOps>(&getContext()); // Fold away `tensor.dim` operations that can be resolved in terms of its // operand shapes. diff --git a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel index f1a6c4b4bdbf..edb48fe68089 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel +++ b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel @@ -21,7 +21,7 @@ iree_lit_test_suite( "collapse_linalg_generic_on_tensors.mlir", "collapse_reduction.mlir", "attention_fuse_by_expansion.mlir", - "fold_transpose.mlir", + "elementwise_op_fusion.mlir", "dispatch_linalg_transform_dialect.mlir", "dispatch_region_formation_preprocessing.mlir", "fold_unit_dims.mlir", diff --git a/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt b/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt index 769ca7620613..02f6a1cfb174 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt @@ -27,7 +27,7 @@ iree_lit_test_suite( "dispatch_linalg_on_tensors_fusion_with_transpose.mlir" "dispatch_linalg_transform_dialect.mlir" "dispatch_region_formation_preprocessing.mlir" - "fold_transpose.mlir" + "elementwise_op_fusion.mlir" "fold_unit_dims.mlir" "form_dispatch_regions.mlir" "form_dispatch_workgroups.mlir" diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir similarity index 72% rename from compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir rename to compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir index 8164ff3ba121..8b556a03835d 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir @@ -124,3 +124,86 @@ util.func public @transpose_matmul(%arg0 : tensor<100x100xf16>, %arg1 : tensor<1 // CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)> // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] + +// ----- + +util.func public @fuse_generic_gather( + %11 :tensor<128256x4096xf16>, %12 : tensor<4x?xi64>, + %13 : tensor<4x?x4096xf32>, %14 : tensor<128256x4096xf32>) + -> tensor<4x?x4096xf32>{ + + %15 = linalg.generic { + indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%11 : tensor<128256x4096xf16>) + outs(%14 : tensor<128256x4096xf32>) { + ^bb0(%in: f16, %out: f32): + %17 = arith.extf %in : f16 to f32 + linalg.yield %17 : f32 + } -> tensor<128256x4096xf32> + %16 = linalg.generic { + indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%12 : tensor<4x?xi64>) + outs(%13 : tensor<4x?x4096xf32>) { + ^bb0(%in: i64, %out: f32): + %17 = arith.index_cast %in : i64 to index + %18 = linalg.index 2 : index + %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32> + linalg.yield %extracted : f32 + } -> tensor<4x?x4096xf32> + util.return %16 : tensor<4x?x4096xf32> +} + +// CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index +// CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index +// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16> +// CHECK-NEXT: %[[RES:[a-zA-Z0-9]+]] = arith.extf %[[EXTRACTED]] : f16 to f32 +// CHECK-NEXT: linalg.yield %[[RES]] : f32 + + +// ----- + +util.func public @fuse_generic_gather2( + %11 :tensor<128256x4096xf16>, %12 : tensor<4x?xi64>, + %13 : tensor<4x?x4096xf32>, %14 : tensor<128256x4096xf32>) + -> tensor<4x?x4096xf32>{ + + %15 = linalg.generic { + indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%11 : tensor<128256x4096xf16>) + outs(%14 : tensor<128256x4096xf32>) { + ^bb0(%in: f16, %out: f32): + %17 = arith.extf %in : f16 to f32 + linalg.yield %17 : f32 + } -> tensor<128256x4096xf32> + %16 = linalg.generic { + indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%12 : tensor<4x?xi64>) + outs(%13 : tensor<4x?x4096xf32>) { + ^bb0(%in: i64, %out: f32): + %17 = arith.index_cast %in : i64 to index + %18 = linalg.index 2 : index + %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32> + %result = arith.addf %extracted, %extracted : f32 + %result2 = arith.mulf %extracted, %extracted : f32 + %final = arith.addf %result, %result2 : f32 + linalg.yield %final: f32 + } -> tensor<4x?x4096xf32> + util.return %16 : tensor<4x?x4096xf32> +} + +// CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index +// CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index +// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16> +// CHECK-NEXT: %[[RES:[a-zA-Z0-9]+]] = arith.extf %[[EXTRACTED]] : f16 to f32 +// CHECK-NEXT: %[[RES2:[a-zA-Z0-9]+]] = arith.addf %[[RES]], %[[RES]] : f32 +// CHECK-NEXT: %[[RES3:[a-zA-Z0-9]+]] = arith.mulf %[[RES]], %[[RES]] : f32 +// CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32 +// CHECK-NEXT: linalg.yield %[[RES4]] : f32 diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fusion_preprocessing.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fusion_preprocessing.mlir index 209785d36867..14e7df57c127 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fusion_preprocessing.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fusion_preprocessing.mlir @@ -31,90 +31,6 @@ util.func public @fold_insert_slices(%source : tensor, // CHECK-SAME: [%[[NEW_OFFSET0]], %[[NEW_OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]] // CHECK: util.return %[[RETURN]] - -// ----- - -util.func public @fuse_generic_gather( - %11 :tensor<128256x4096xf16>, %12 : tensor<4x?xi64>, - %13 : tensor<4x?x4096xf32>, %14 : tensor<128256x4096xf32>) - -> tensor<4x?x4096xf32>{ - - %15 = linalg.generic { - indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%11 : tensor<128256x4096xf16>) - outs(%14 : tensor<128256x4096xf32>) { - ^bb0(%in: f16, %out: f32): - %17 = arith.extf %in : f16 to f32 - linalg.yield %17 : f32 - } -> tensor<128256x4096xf32> - %16 = linalg.generic { - indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1, d2)>], - iterator_types = ["parallel", "parallel", "parallel"]} - ins(%12 : tensor<4x?xi64>) - outs(%13 : tensor<4x?x4096xf32>) { - ^bb0(%in: i64, %out: f32): - %17 = arith.index_cast %in : i64 to index - %18 = linalg.index 2 : index - %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32> - linalg.yield %extracted : f32 - } -> tensor<4x?x4096xf32> - util.return %16 : tensor<4x?x4096xf32> -} - -// CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index -// CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index -// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16> -// CHECK-NEXT: %[[RES:[a-zA-Z0-9]+]] = arith.extf %[[EXTRACTED]] : f16 to f32 -// CHECK-NEXT: linalg.yield %[[RES]] : f32 - - -// ----- - -util.func public @fuse_generic_gather2( - %11 :tensor<128256x4096xf16>, %12 : tensor<4x?xi64>, - %13 : tensor<4x?x4096xf32>, %14 : tensor<128256x4096xf32>) - -> tensor<4x?x4096xf32>{ - - %15 = linalg.generic { - indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%11 : tensor<128256x4096xf16>) - outs(%14 : tensor<128256x4096xf32>) { - ^bb0(%in: f16, %out: f32): - %17 = arith.extf %in : f16 to f32 - linalg.yield %17 : f32 - } -> tensor<128256x4096xf32> - %16 = linalg.generic { - indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1, d2)>], - iterator_types = ["parallel", "parallel", "parallel"]} - ins(%12 : tensor<4x?xi64>) - outs(%13 : tensor<4x?x4096xf32>) { - ^bb0(%in: i64, %out: f32): - %17 = arith.index_cast %in : i64 to index - %18 = linalg.index 2 : index - %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32> - %result = arith.addf %extracted, %extracted : f32 - %result2 = arith.mulf %extracted, %extracted : f32 - %final = arith.addf %result, %result2 : f32 - linalg.yield %final: f32 - } -> tensor<4x?x4096xf32> - util.return %16 : tensor<4x?x4096xf32> -} - -// CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index -// CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index -// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16> -// CHECK-NEXT: %[[RES:[a-zA-Z0-9]+]] = arith.extf %[[EXTRACTED]] : f16 to f32 -// CHECK-NEXT: %[[RES2:[a-zA-Z0-9]+]] = arith.addf %[[RES]], %[[RES]] : f32 -// CHECK-NEXT: %[[RES3:[a-zA-Z0-9]+]] = arith.mulf %[[RES]], %[[RES]] : f32 -// CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32 -// CHECK-NEXT: linalg.yield %[[RES4]] : f32 - // ----- #ident = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>