From 4fe9aa7728528114078c3a407fce31c3759a7a5c Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 7 Aug 2024 19:44:23 +0200 Subject: [PATCH 1/2] Fold into eltwise pass - fold broadcast --- include/TPP/Passes.td | 11 ++ lib/TPP/Transforms/CMakeLists.txt | 1 + lib/TPP/Transforms/FoldIntoEltwise.cpp | 116 +++++++++++++ test/Passes/pass-fold-into-eltwise.mlir | 208 ++++++++++++++++++++++++ 4 files changed, 336 insertions(+) create mode 100644 lib/TPP/Transforms/FoldIntoEltwise.cpp create mode 100644 test/Passes/pass-fold-into-eltwise.mlir diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index 8ec2bef5e..68aa58212 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -504,4 +504,15 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> { ]; } +def FoldIntoEltwise : Pass<"fold-into-eltwise", "ModuleOp"> { + let summary = "Fold operations into elementwise ops."; + let description = [{ + Fold operations into Linalg elementwise ops. + Results in linalg.generic representation. + }]; + let dependentDialects = ["linalg::LinalgDialect", + "arith::ArithDialect", + "affine::AffineDialect"]; +} + #endif // TPP_DIALECT_TPP_PASSES diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt index fb7385b5c..4f2def0f0 100644 --- a/lib/TPP/Transforms/CMakeLists.txt +++ b/lib/TPP/Transforms/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_library(TPPTransforms IntelAMXTileConfigHoisting.cpp LinalgConvertCompareSelectToMaximumfPass.cpp ConvertAddInplacePass.cpp + FoldIntoEltwise.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/TPP diff --git a/lib/TPP/Transforms/FoldIntoEltwise.cpp b/lib/TPP/Transforms/FoldIntoEltwise.cpp new file mode 100644 index 000000000..e41cfe74d --- /dev/null +++ b/lib/TPP/Transforms/FoldIntoEltwise.cpp @@ -0,0 +1,116 @@ +//===- FoldIntoEltwise.cpp ---------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TPP/Passes.h" +#include "TPP/Transforms/Transforms.h" +#include "TPP/Transforms/Utils/TransformUtils.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace tpp; + +namespace mlir { +namespace tpp { +#define GEN_PASS_DEF_FOLDINTOELTWISE +#include "TPP/Passes.h.inc" +} // namespace tpp +} // namespace mlir + +namespace { + +// Create affine map between a producer operand and a consumer op indexing. +// Assumes that the provided maps are composable. +static AffineMap +reindexProducerOperandIntoConsumer(AffineMap producerOperandMap, + AffineMap producerResultMap, + AffineMap consumerMap) { + // Index producer result dimensions to its loops. + AffineMap invProducerResultMap = inversePermutation(producerResultMap); + // Index producer operand with respect to the producer result dimensions. + AffineMap operandToResultMap = + producerOperandMap.compose(invProducerResultMap); + // Remap producer operand into consumer indexing. + return operandToResultMap.compose(consumerMap); +} + +// Fold linalg.broadcast into a linalg elementwise operation. +struct BroadcastIntoEltwise + : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, + PatternRewriter &rewriter) const override { + if (!linalg::isElementwise(linalgOp)) + return rewriter.notifyMatchFailure(linalgOp, + "not an elementwise operation"); + + if (!linalgOp.hasPureTensorSemantics()) + return rewriter.notifyMatchFailure(linalgOp, "expects tensor semantics"); + + // Look for broadcasts within inputs. + // Reshaping output might be less beneficial and it is not considered now. + if (llvm::none_of(linalgOp.getDpsInputs(), [](Value input) { + auto op = input.getDefiningOp(); + return op && isa(op); + })) + return rewriter.notifyMatchFailure(linalgOp, "no broadcast producers"); + + SmallVector inputs = linalgOp.getDpsInputs(); + ValueRange outputs = linalgOp.getDpsInits(); + SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); + SmallVector iterators = + linalgOp.getIteratorTypesArray(); + SmallVector resultTypes = TypeRange(ValueRange{outputs}); + + for (auto [idx, input] : llvm::enumerate(linalgOp.getDpsInputs())) { + auto broadcast = input.getDefiningOp(); + if (!broadcast) + continue; + + // Update indexing maps. + // The broadcasting can be captured by indexing maps alone w.r.t broadcast + // input and consumer iteration domain. + indexingMaps[idx] = reindexProducerOperandIntoConsumer( + broadcast.getMatchingIndexingMap(broadcast.getDpsInputOperand(0)), + broadcast.getMatchingIndexingMap(broadcast.getDpsInitOperand(0)), + indexingMaps[idx]); + // Use the broadcast input directly instead of the broadcast result. + inputs[idx] = broadcast.getInput(); + } + + // All Linalg ops have a region attached that can be inlined. + assert(linalgOp->getNumRegions() == 1 && + "expect op to have one region attached"); + // Replace the original op with a generic with broadcast folded in. + auto genericOp = rewriter.create( + linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, + iterators); + rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(), + genericOp.getRegion().begin()); + rewriter.replaceOp(linalgOp, genericOp->getResults()); + + return success(); + } +}; + +struct FoldIntoEltwise : tpp::impl::FoldIntoEltwiseBase { + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(patterns.getContext()); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace diff --git a/test/Passes/pass-fold-into-eltwise.mlir b/test/Passes/pass-fold-into-eltwise.mlir new file mode 100644 index 000000000..17c17f666 --- /dev/null +++ b/test/Passes/pass-fold-into-eltwise.mlir @@ -0,0 +1,208 @@ +// RUN: tpp-opt %s -fold-into-eltwise -split-input-file | FileCheck %s + +func.func @broadcast_into_add_outer_dim(%arg0: tensor<8xf32>, + %arg1: tensor<16x8xf32>) -> tensor<16x8xf32> { + %e = tensor.empty() : tensor<16x8xf32> + %0 = linalg.broadcast ins(%arg0 : tensor<8xf32>) outs(%e : tensor<16x8xf32>) dimensions = [0] + %1 = linalg.add ins(%0, %arg1 : tensor<16x8xf32>, tensor<16x8xf32>) + outs(%e : tensor<16x8xf32>) -> tensor<16x8xf32> + return %1 : tensor<16x8xf32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK-LABEL: @broadcast_into_add_outer_dim( +// CHECK-SAME: %[[ARG0:.+]]: tensor<8xf32> +// CHECK-NOT: linalg.broadcast +// CHECK: linalg.generic{{.*}}indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP1]]] +// CHECK-SAME: ins(%[[ARG0]],{{.*}}) +// CHECK: arith.addf +// CHECK: linalg.yield + +// ----- + +func.func @broadcast_into_add_inner_dim(%arg0: tensor<8xf32>, + %arg1: tensor<8x4xf32>) -> tensor<8x4xf32> { + %e = tensor.empty() : tensor<8x4xf32> + %0 = linalg.broadcast ins(%arg0 : tensor<8xf32>) outs(%e : tensor<8x4xf32>) dimensions = [1] + %1 = linalg.add ins(%0, %arg1 : tensor<8x4xf32>, tensor<8x4xf32>) + outs(%e : tensor<8x4xf32>) -> tensor<8x4xf32> + return %1 : tensor<8x4xf32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK-LABEL: @broadcast_into_add_inner_dim( +// CHECK-SAME: %[[ARG0:.+]]: tensor<8xf32> +// CHECK-NOT: linalg.broadcast +// CHECK: linalg.generic{{.*}}indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP1]]] +// CHECK-SAME: ins(%[[ARG0]],{{.*}}) +// CHECK: arith.addf +// CHECK: linalg.yield + +// ----- + +func.func @broadcast_into_mul(%arg0: tensor<8xf32>, + %arg1: tensor<8x4xf32>) -> tensor<8x4xf32> { + %e = tensor.empty() : tensor<8x4xf32> + %0 = linalg.broadcast ins(%arg0 : tensor<8xf32>) outs(%e : tensor<8x4xf32>) dimensions = [1] + %1 = linalg.mul ins(%0, %arg1 : tensor<8x4xf32>, tensor<8x4xf32>) + outs(%e : tensor<8x4xf32>) -> tensor<8x4xf32> + return %1 : tensor<8x4xf32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK-LABEL: @broadcast_into_mul( +// CHECK-SAME: %[[ARG0:.+]]: tensor<8xf32> +// CHECK-NOT: linalg.broadcast +// CHECK: linalg.generic{{.*}}indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP1]]] +// CHECK-SAME: ins(%[[ARG0]],{{.*}}) +// CHECK: arith.mulf +// CHECK: linalg.yield + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +func.func @broadcast_into_generic(%arg0: tensor<4xf32>, + %arg1: tensor<4x2x8xf32>) -> tensor<4x2x8xf32> { + %e = tensor.empty() : tensor<4x8xf32> + %0 = linalg.broadcast ins(%arg0 : tensor<4xf32>) outs(%e : tensor<4x8xf32>) dimensions = [1] + %1 = linalg.generic {indexing_maps = [#map, #map1], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x2x8xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<4x2x8xf32> + return %1 : tensor<4x2x8xf32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> + +// CHECK-LABEL: @broadcast_into_generic( +// CHECK-SAME: %[[ARG0:.+]]: tensor<4xf32> +// CHECK-NOT: linalg.broadcast +// CHECK: linalg.generic{{.*}}indexing_maps = [#[[MAP]], #[[MAP1]]] +// CHECK-SAME: ins(%[[ARG0]] :{{.*}}) + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d2, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +func.func @broadcast_into_generic_transposed(%arg0: tensor<8xf32>, + %arg1: tensor<4x2x8xf32>) -> tensor<4x2x8xf32> { + %e = tensor.empty() : tensor<8x4xf32> + %0 = linalg.broadcast ins(%arg0 : tensor<8xf32>) outs(%e : tensor<8x4xf32>) dimensions = [1] + %1 = linalg.generic {indexing_maps = [#map, #map1], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%0 : tensor<8x4xf32>) outs(%arg1 : tensor<4x2x8xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<4x2x8xf32> + return %1 : tensor<4x2x8xf32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> + +// CHECK-LABEL: @broadcast_into_generic_transposed( +// CHECK-SAME: %[[ARG0:.+]]: tensor<8xf32> +// CHECK-NOT: linalg.broadcast +// CHECK: linalg.generic{{.*}}indexing_maps = [#[[MAP]], #[[MAP1]]] +// CHECK-SAME: ins(%[[ARG0]] :{{.*}}) + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @broadcast_into_generic_multidim(%arg0: tensor<6x2xf32>, + %arg1: tensor<2x4x6x8xf32>) -> tensor<2x4x6x8xf32> { + %e = tensor.empty() : tensor<6x4x2xf32> + %0 = linalg.broadcast ins(%arg0 : tensor<6x2xf32>) outs(%e : tensor<6x4x2xf32>) dimensions = [1] + %1 = linalg.generic {indexing_maps = [#map, #map1], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%0 : tensor<6x4x2xf32>) outs(%arg1 : tensor<2x4x6x8xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<2x4x6x8xf32> + return %1 : tensor<2x4x6x8xf32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +// CHECK-LABEL: @broadcast_into_generic_multidim( +// CHECK-SAME: %[[ARG0:.+]]: tensor<6x2xf32> +// CHECK-NOT: linalg.broadcast +// CHECK: linalg.generic{{.*}}indexing_maps = [#[[MAP]], #[[MAP1]]] +// CHECK-SAME: ins(%[[ARG0]] :{{.*}}) + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @broadcast_into_generic_multiple_operands(%arg0: tensor<4xf32>, + %arg1: tensor<4x8xf32>, %arg2: tensor<4x8xf32>) -> tensor<4x8xf32> { + %e = tensor.empty() : tensor<4x8xf32> + %0 = linalg.broadcast ins(%arg0 : tensor<4xf32>) outs(%e : tensor<4x8xf32>) dimensions = [1] + %1 = linalg.generic {indexing_maps = [#map, #map, #map], + iterator_types = ["parallel", "parallel"]} + ins(%arg1, %0 : tensor<4x8xf32>, tensor<4x8xf32>) outs(%arg2 : tensor<4x8xf32>) { + ^bb0(%in: f32, %in1: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + %2 = arith.addf %1, %in1 : f32 + linalg.yield %2 : f32 + } -> tensor<4x8xf32> + return %1 : tensor<4x8xf32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)> + +// CHECK-LABEL: @broadcast_into_generic_multiple_operands( +// CHECK-SAME: %[[ARG0:.+]]: tensor<4xf32>, +// CHECK-SAME: %[[ARG1:.+]]: tensor<4x8xf32>, +// CHECK-SAME: %[[ARG2:.+]]: tensor<4x8xf32> +// CHECK-NOT: linalg.broadcast +// CHECK: linalg.generic{{.*}}indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP]]] +// CHECK-SAME: ins(%[[ARG1]], %[[ARG0]] :{{.*}}) + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @no_fold_non_eltwise(%arg0: tensor<16xf32>, + %arg1: tensor<32x64xf32>, + %arg2: tensor<16x64xf32>) -> tensor<16x64xf32> { + %e = tensor.empty() : tensor<16x32xf32> + %0 = linalg.broadcast ins(%arg0 : tensor<16xf32>) outs(%e : tensor<16x32xf32>) dimensions = [1] + %1 = linalg.matmul ins(%0, %arg1 : tensor<16x32xf32>, tensor<32x64xf32>) + outs(%arg2 : tensor<16x64xf32>) -> tensor<16x64xf32> + return %1 : tensor<16x64xf32> +} + +// CHECK-LABEL: @no_fold_non_eltwise( +// CHECK: linalg.broadcast +// CHECK: linalg.matmul + +// ----- + +func.func @no_fold_non_tensor(%arg0: memref<8xf32>, + %arg1: memref<8x4xf32>, + %arg2: memref<8x4xf32>) { + linalg.broadcast ins(%arg0 : memref<8xf32>) outs(%arg2 : memref<8x4xf32>) dimensions = [1] + linalg.add ins(%arg2, %arg1 : memref<8x4xf32>, memref<8x4xf32>) + outs(%arg2 : memref<8x4xf32>) + return +} + +// CHECK-LABEL: @no_fold_non_tensor( +// CHECK: linalg.broadcast +// CHECK: linalg.add From 7d31572c9bb78fe533cfaf82ec00bf2684df7aed Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 7 Aug 2024 20:42:18 +0200 Subject: [PATCH 2/2] Add pass to pipeline --- lib/TPP/DefaultTppPasses.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/TPP/DefaultTppPasses.cpp b/lib/TPP/DefaultTppPasses.cpp index e768225ee..5cb198d45 100644 --- a/lib/TPP/DefaultTppPasses.cpp +++ b/lib/TPP/DefaultTppPasses.cpp @@ -77,6 +77,7 @@ struct DefaultTppPasses pm.addNestedPass(createConvertLinalgToLoopsPass()); pm.addPass(createCleanup()); } else { + pm.addPass(createFoldIntoEltwise()); pm.addNestedPass(createConvertAddInplacePass()); // Convert linalg.batch_matmul to linalg.matmul. pm.addPass(createRewriteBatchMatmulToMatmul());