From 6bdaf155193da52bdab84b1c3bedc98de416e0e2 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 9 Aug 2024 13:32:06 +0200 Subject: [PATCH] Support inplace generic in compare-select to max rewrite (#956) Simplifies conversion to only change generic's body operation without affecting operation operands. This prevents crashes when generic has inplace format. Originally the pass would also implicitly rewrite generic operation to be in place by replacing output with the input. This is removed to improve robustness and better separate concerns. Inplace rewrite will be handled by a separate pass later on. --- include/TPP/Passes.td | 3 +- ...nalgConvertCompareSelectToMaximumfPass.cpp | 8 +-- .../linalg-convert-cmp-select-maximumf.mlir | 69 ++++++++++++++----- 3 files changed, 55 insertions(+), 25 deletions(-) diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index 68aa58212..3fe4e10cb 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -417,7 +417,8 @@ def LinalgConvertCompareSelectToMaximumfPass: Pass<"linalg-convert-compare-selec let description = [{ Convert linalg generic compare-select operation to maximumf operation. }]; - let dependentDialects = ["linalg::LinalgDialect"]; + let dependentDialects = ["linalg::LinalgDialect", + "arith::ArithDialect"]; } def ConvertAddInplacePass: Pass<"linalg-convert-add-in-place", diff --git a/lib/TPP/Transforms/LinalgConvertCompareSelectToMaximumfPass.cpp b/lib/TPP/Transforms/LinalgConvertCompareSelectToMaximumfPass.cpp index eda3a3486..5ce760fec 100644 --- a/lib/TPP/Transforms/LinalgConvertCompareSelectToMaximumfPass.cpp +++ b/lib/TPP/Transforms/LinalgConvertCompareSelectToMaximumfPass.cpp @@ -56,13 +56,7 @@ struct LinalgConvertCompareSelectToMaximumf dyn_cast(op.getBody()->getOperations().begin()) ->getOperands()); dyn_cast(op.getBody()->getTerminator()).setOperand(0, maxf); - op.getOutputsMutable().clear(); - ValueRange range{op.getInputsMutable()}; - op.getOutputsMutable().append(range); - op.getInputsMutable().clear(); - op.setIndexingMapsAttr( - ArrayAttr::get(rewriter.getContext(), op.getIndexingMaps()[0])); - op.getBody()->eraseArgument(1); + // Deletion in reverse order due to dependences rewriter.eraseOp(select); rewriter.eraseOp(cmpf); diff --git a/test/Passes/linalg-convert-cmp-select-maximumf.mlir b/test/Passes/linalg-convert-cmp-select-maximumf.mlir index b04a44791..65ab8a691 100644 --- a/test/Passes/linalg-convert-cmp-select-maximumf.mlir +++ b/test/Passes/linalg-convert-cmp-select-maximumf.mlir @@ -1,26 +1,61 @@ -// RUN: tpp-opt --linalg-convert-compare-select-to-maximumf-pass %s --split-input-file | FileCheck %s +// RUN: tpp-opt %s --linalg-convert-compare-select-to-maximumf-pass --split-input-file | FileCheck %s -func.func @forward() -> tensor<256x1024xf32>{ -%cst_5 = arith.constant 0.000000e+00 : f32 -%5 = tensor.empty() : tensor<256x1024xf32> -%2 = tensor.empty() : tensor<256x1024xf32> -%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<256x1024xf32>) outs(%2 : tensor<256x1024xf32>) { - ^bb0(%in: f32, %out: f32): - %15 = arith.cmpf ugt, %in, %cst_5 : f32 - %16 = arith.select %15, %in, %cst_5 : f32 - linalg.yield %16 : f32 - } -> tensor<256x1024xf32> +func.func @compare_select_to_max(%arg0: tensor<256x1024xf32>) -> tensor<256x1024xf32>{ + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<256x1024xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<256x1024xf32>) outs(%0 : tensor<256x1024xf32>) { + ^bb0(%in: f32, %out: f32): + %15 = arith.cmpf ugt, %in, %cst : f32 + %16 = arith.select %15, %in, %cst : f32 + linalg.yield %16 : f32 + } -> tensor<256x1024xf32> -return %6: tensor<256x1024xf32> + return %1 : tensor<256x1024xf32> } // CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> // CHECK: module { -// CHECK: func.func @forward() -// CHECK: -> tensor<256x1024xf32> { +// CHECK: func.func @compare_select_to_max( +// CHECK-SAME: %[[ARG0:.+]]: tensor<256x1024xf32> // CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[temp0:.*]] = tensor.empty() : tensor<256x1024xf32> -// CHECK: %[[temp1:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[temp0]] : tensor<256x1024xf32>) { +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<256x1024xf32> +// CHECK: %[[temp1:.*]] = linalg.generic {indexing_maps = [#map, #map] +// CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: ins(%[[ARG0]] : tensor<256x1024xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<256x1024xf32>) { +// CHECK: ^bb0(%[[in:.*]]: f32, %[[out:.*]]: f32): +// CHECK: %[[temp2:.*]] = arith.maximumf %[[in]], %[[cst]] : f32 +// CHECK: linalg.yield %[[temp2]] : f32 +// CHECK: } -> tensor<256x1024xf32> +// CHECK: return %[[temp1]] : tensor<256x1024xf32> + + +// ----- + +func.func @compare_select_to_max_inplace(%arg0: tensor<256x1024xf32>) -> tensor<256x1024xf32>{ + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + outs(%arg0 : tensor<256x1024xf32>) { + ^bb0(%out: f32): + %15 = arith.cmpf ugt, %out, %cst : f32 + %16 = arith.select %15, %out, %cst : f32 + linalg.yield %16 : f32 + } -> tensor<256x1024xf32> + + return %0 : tensor<256x1024xf32> +} + +// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: module { +// CHECK: func.func @compare_select_to_max_inplace( +// CHECK-SAME: %[[ARG0:.+]]: tensor<256x1024xf32> +// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[temp1:.*]] = linalg.generic {indexing_maps = [#map] +// CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: outs(%[[ARG0]] : tensor<256x1024xf32>) { // CHECK: ^bb0(%[[out:.*]]: f32): // CHECK: %[[temp2:.*]] = arith.maximumf %[[out]], %[[cst]] : f32 // CHECK: linalg.yield %[[temp2]] : f32 @@ -41,7 +76,7 @@ func.func @non_zero_compare() -> tensor<256x1024xf32>{ linalg.yield %16 : f32 } -> tensor<256x1024xf32> -return %6: tensor<256x1024xf32> + return %6: tensor<256x1024xf32> } // CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>