Skip to content

Commit

Permalink
Extra test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-smnk committed Aug 9, 2024
1 parent fdbcb63 commit cdc95a8
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions test/Passes/convert-linalg-to-inplace.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,38 @@ func.func @generic_eltwise_unary_to_inplace(%arg0: tensor<8x4xf32>, %arg1: f32)

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @generic_eltwise_unary_dynamic_to_inplace(
%arg0: tensor<?x?xf32>, %arg1: f32) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%0 = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
%1 = linalg.generic {indexing_maps = [#map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%2 = arith.mulf %in, %arg1 : f32
linalg.yield %2 : f32
} -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}

// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: func.func @generic_eltwise_unary_dynamic_to_inplace(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[ARG1:.*]]: f32
// CHECK: linalg.generic{{.*}}indexing_maps = [#[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: outs(%[[ARG0]] :{{.*}})
// CHECK: ^bb0(%[[out:.*]]: f32):
// CHECK: %[[RES:.*]] = arith.mulf %[[out]], %[[ARG1]]
// CHECK: linalg.yield %[[RES]]

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @no_inplace_generic_used_output(%arg0: tensor<8x4xf32>, %arg1: tensor<8x4xf32>) -> tensor<8x4xf32> {
%0 = linalg.generic {indexing_maps = [#map, #map],
Expand Down Expand Up @@ -148,3 +180,27 @@ func.func @no_inplace_generic_mismatched_maps(
// CHECK: %[[EMPTY:.+]] = tensor.empty
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[ARG0]] :{{.*}}) outs(%[[EMPTY]] :{{.*}})

// -----

#map = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
func.func @no_inplace_generic_transposed_maps(
%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
%cst = arith.constant 2.0 : f32
%0 = tensor.empty() : tensor<8x8xf32>
%1 = linalg.generic {indexing_maps = [#map, #map1],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) {
^bb0(%in: f32, %out: f32):
%2 = arith.mulf %in, %cst : f32
linalg.yield %2 : f32
} -> tensor<8x8xf32>
return %1 : tensor<8x8xf32>
}

// CHECK-LABEL: func.func @no_inplace_generic_transposed_maps(
// CHECK-SAME: %[[ARG0:.*]]: tensor<8x8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[ARG0]] :{{.*}}) outs(%[[EMPTY]] :{{.*}})

0 comments on commit cdc95a8

Please sign in to comment.