diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td
index 68aa58212..3fe4e10cb 100644
--- a/include/TPP/Passes.td
+++ b/include/TPP/Passes.td
@@ -417,7 +417,8 @@ def LinalgConvertCompareSelectToMaximumfPass: Pass<"linalg-convert-compare-selec
   let description = [{
     Convert linalg generic compare-select operation to maximumf operation.
   }];
-  let dependentDialects = ["linalg::LinalgDialect"];
+  let dependentDialects = ["linalg::LinalgDialect",
+                           "arith::ArithDialect"];
 }
 
 def ConvertAddInplacePass: Pass<"linalg-convert-add-in-place",
diff --git a/lib/TPP/Transforms/LinalgConvertCompareSelectToMaximumfPass.cpp b/lib/TPP/Transforms/LinalgConvertCompareSelectToMaximumfPass.cpp
index eda3a3486..5ce760fec 100644
--- a/lib/TPP/Transforms/LinalgConvertCompareSelectToMaximumfPass.cpp
+++ b/lib/TPP/Transforms/LinalgConvertCompareSelectToMaximumfPass.cpp
@@ -56,13 +56,7 @@ struct LinalgConvertCompareSelectToMaximumf
         dyn_cast<arith::CmpFOp>(op.getBody()->getOperations().begin())
             ->getOperands());
     dyn_cast<linalg::YieldOp>(op.getBody()->getTerminator()).setOperand(0, maxf);
-    op.getOutputsMutable().clear();
-    ValueRange range{op.getInputsMutable()};
-    op.getOutputsMutable().append(range);
-    op.getInputsMutable().clear();
-    op.setIndexingMapsAttr(
-        ArrayAttr::get(rewriter.getContext(), op.getIndexingMaps()[0]));
-    op.getBody()->eraseArgument(1);
+    // Deletion in reverse order due to dependences
     rewriter.eraseOp(select);
     rewriter.eraseOp(cmpf);
 
diff --git a/test/Passes/linalg-convert-cmp-select-maximumf.mlir b/test/Passes/linalg-convert-cmp-select-maximumf.mlir
index b04a44791..65ab8a691 100644
--- a/test/Passes/linalg-convert-cmp-select-maximumf.mlir
+++ b/test/Passes/linalg-convert-cmp-select-maximumf.mlir
@@ -1,26 +1,61 @@
-// RUN: tpp-opt --linalg-convert-compare-select-to-maximumf-pass %s --split-input-file | FileCheck %s
+// RUN: tpp-opt %s --linalg-convert-compare-select-to-maximumf-pass --split-input-file | FileCheck %s
 
-func.func @forward() -> tensor<256x1024xf32>{
-%cst_5 = arith.constant 0.000000e+00 : f32
-%5 = tensor.empty() : tensor<256x1024xf32>
-%2 = tensor.empty() : tensor<256x1024xf32>
-%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<256x1024xf32>) outs(%2 : tensor<256x1024xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    %15 = arith.cmpf ugt, %in, %cst_5 : f32
-    %16 = arith.select %15, %in, %cst_5 : f32
-    linalg.yield %16 : f32
-  } -> tensor<256x1024xf32>
+func.func @compare_select_to_max(%arg0: tensor<256x1024xf32>) -> tensor<256x1024xf32>{
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<256x1024xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+                       iterator_types = ["parallel", "parallel"]}
+                       ins(%arg0 : tensor<256x1024xf32>) outs(%0 : tensor<256x1024xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %15 = arith.cmpf ugt, %in, %cst : f32
+      %16 = arith.select %15, %in, %cst : f32
+      linalg.yield %16 : f32
+    } -> tensor<256x1024xf32>
 
-return %6: tensor<256x1024xf32>
+  return %1 : tensor<256x1024xf32>
 }
 
 // CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK: module {
-// CHECK: func.func @forward()
-// CHECK: -> tensor<256x1024xf32> {
+// CHECK: func.func @compare_select_to_max(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<256x1024xf32>
 // CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[temp0:.*]] = tensor.empty() : tensor<256x1024xf32>
-// CHECK: %[[temp1:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[temp0]] : tensor<256x1024xf32>) {
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<256x1024xf32>
+// CHECK: %[[temp1:.*]] = linalg.generic {indexing_maps = [#map, #map]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[ARG0]] : tensor<256x1024xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<256x1024xf32>) {
+// CHECK: ^bb0(%[[in:.*]]: f32, %[[out:.*]]: f32):
+// CHECK: %[[temp2:.*]] = arith.maximumf %[[in]], %[[cst]] : f32
+// CHECK: linalg.yield %[[temp2]] : f32
+// CHECK: } -> tensor<256x1024xf32>
+// CHECK: return %[[temp1]] : tensor<256x1024xf32>
+
+
+// -----
+
+func.func @compare_select_to_max_inplace(%arg0: tensor<256x1024xf32>) -> tensor<256x1024xf32>{
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+                       iterator_types = ["parallel", "parallel"]}
+                       outs(%arg0 : tensor<256x1024xf32>) {
+    ^bb0(%out: f32):
+      %15 = arith.cmpf ugt, %out, %cst : f32
+      %16 = arith.select %15, %out, %cst : f32
+      linalg.yield %16 : f32
+    } -> tensor<256x1024xf32>
+
+  return %0 : tensor<256x1024xf32>
+}
+
+// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: module {
+// CHECK: func.func @compare_select_to_max_inplace(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<256x1024xf32>
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[temp1:.*]] = linalg.generic {indexing_maps = [#map]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: outs(%[[ARG0]] : tensor<256x1024xf32>) {
 // CHECK: ^bb0(%[[out:.*]]: f32):
 // CHECK: %[[temp2:.*]] = arith.maximumf %[[out]], %[[cst]] : f32
 // CHECK: linalg.yield %[[temp2]] : f32
@@ -41,7 +76,7 @@ func.func @non_zero_compare() -> tensor<256x1024xf32>{
     linalg.yield %16 : f32
   } -> tensor<256x1024xf32>
 
-return %6: tensor<256x1024xf32>
+  return %6: tensor<256x1024xf32>
 }
 
 // CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>