Skip to content

Commit

Permalink
Support inplace generic in compare-select to max rewrite (#956)
Browse files Browse the repository at this point in the history
Simplifies the conversion to only change the generic's body operations,
without affecting the operation's operands. This prevents crashes when the
generic is already in in-place form.

Originally, the pass would also implicitly rewrite the generic operation to
be in place by replacing its output with the input. This is removed to
improve robustness and better separate concerns. The in-place rewrite will be
handled by a separate pass later on.
  • Loading branch information
adam-smnk authored Aug 9, 2024
1 parent 2425032 commit 6bdaf15
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 25 deletions.
3 changes: 2 additions & 1 deletion include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ def LinalgConvertCompareSelectToMaximumfPass: Pass<"linalg-convert-compare-selec
let description = [{
Convert linalg generic compare-select operation to maximumf operation.
}];
let dependentDialects = ["linalg::LinalgDialect"];
let dependentDialects = ["linalg::LinalgDialect",
"arith::ArithDialect"];
}

def ConvertAddInplacePass: Pass<"linalg-convert-add-in-place",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,7 @@ struct LinalgConvertCompareSelectToMaximumf
dyn_cast<arith::CmpFOp>(op.getBody()->getOperations().begin())
->getOperands());
dyn_cast<YieldOp>(op.getBody()->getTerminator()).setOperand(0, maxf);
op.getOutputsMutable().clear();
ValueRange range{op.getInputsMutable()};
op.getOutputsMutable().append(range);
op.getInputsMutable().clear();
op.setIndexingMapsAttr(
ArrayAttr::get(rewriter.getContext(), op.getIndexingMaps()[0]));
op.getBody()->eraseArgument(1);

// Deletion in reverse order due to dependences
rewriter.eraseOp(select);
rewriter.eraseOp(cmpf);
Expand Down
69 changes: 52 additions & 17 deletions test/Passes/linalg-convert-cmp-select-maximumf.mlir
Original file line number Diff line number Diff line change
@@ -1,26 +1,61 @@
// RUN: tpp-opt --linalg-convert-compare-select-to-maximumf-pass %s --split-input-file | FileCheck %s
// RUN: tpp-opt %s --linalg-convert-compare-select-to-maximumf-pass --split-input-file | FileCheck %s

func.func @forward() -> tensor<256x1024xf32>{
%cst_5 = arith.constant 0.000000e+00 : f32
%5 = tensor.empty() : tensor<256x1024xf32>
%2 = tensor.empty() : tensor<256x1024xf32>
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<256x1024xf32>) outs(%2 : tensor<256x1024xf32>) {
^bb0(%in: f32, %out: f32):
%15 = arith.cmpf ugt, %in, %cst_5 : f32
%16 = arith.select %15, %in, %cst_5 : f32
linalg.yield %16 : f32
} -> tensor<256x1024xf32>
func.func @compare_select_to_max(%arg0: tensor<256x1024xf32>) -> tensor<256x1024xf32>{
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<256x1024xf32>
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<256x1024xf32>) outs(%0 : tensor<256x1024xf32>) {
^bb0(%in: f32, %out: f32):
%15 = arith.cmpf ugt, %in, %cst : f32
%16 = arith.select %15, %in, %cst : f32
linalg.yield %16 : f32
} -> tensor<256x1024xf32>

return %6: tensor<256x1024xf32>
return %1 : tensor<256x1024xf32>
}

// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: module {
// CHECK: func.func @forward()
// CHECK: -> tensor<256x1024xf32> {
// CHECK: func.func @compare_select_to_max(
// CHECK-SAME: %[[ARG0:.+]]: tensor<256x1024xf32>
// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[temp0:.*]] = tensor.empty() : tensor<256x1024xf32>
// CHECK: %[[temp1:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[temp0]] : tensor<256x1024xf32>) {
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<256x1024xf32>
// CHECK: %[[temp1:.*]] = linalg.generic {indexing_maps = [#map, #map]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG0]] : tensor<256x1024xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<256x1024xf32>) {
// CHECK: ^bb0(%[[in:.*]]: f32, %[[out:.*]]: f32):
// CHECK: %[[temp2:.*]] = arith.maximumf %[[in]], %[[cst]] : f32
// CHECK: linalg.yield %[[temp2]] : f32
// CHECK: } -> tensor<256x1024xf32>
// CHECK: return %[[temp1]] : tensor<256x1024xf32>


// -----

// Test input: compare-select pattern inside a linalg.generic that is already
// in in-place form, i.e. it has no `ins` operand and a single indexing map,
// and the body reads the `outs` block argument directly.
func.func @compare_select_to_max_inplace(%arg0: tensor<256x1024xf32>) -> tensor<256x1024xf32>{
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
outs(%arg0 : tensor<256x1024xf32>) {
^bb0(%out: f32):
// cmpf(ugt) + select against the zero constant; the FileCheck expectations
// that follow this function require a single arith.maximumf of %out and %cst
// in their place.
%15 = arith.cmpf ugt, %out, %cst : f32
%16 = arith.select %15, %out, %cst : f32
linalg.yield %16 : f32
} -> tensor<256x1024xf32>

return %0 : tensor<256x1024xf32>
}

// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: module {
// CHECK: func.func @compare_select_to_max_inplace(
// CHECK-SAME: %[[ARG0:.+]]: tensor<256x1024xf32>
// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[temp1:.*]] = linalg.generic {indexing_maps = [#map]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: outs(%[[ARG0]] : tensor<256x1024xf32>) {
// CHECK: ^bb0(%[[out:.*]]: f32):
// CHECK: %[[temp2:.*]] = arith.maximumf %[[out]], %[[cst]] : f32
// CHECK: linalg.yield %[[temp2]] : f32
Expand All @@ -41,7 +76,7 @@ func.func @non_zero_compare() -> tensor<256x1024xf32>{
linalg.yield %16 : f32
} -> tensor<256x1024xf32>

return %6: tensor<256x1024xf32>
return %6: tensor<256x1024xf32>
}

// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
Expand Down

0 comments on commit 6bdaf15

Please sign in to comment.