Skip to content

Commit

Permalink
Updated unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
KavithaTipturMadhu committed Dec 12, 2024
1 parent 12a84e1 commit 07502f2
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions test/Conversion/VectorToXsmm/vector-to-identity.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,59 @@ func.func @identity_subview_copy(%arg0: memref<128x1xf32>, %arg1: memref<512x128
// CHECK: %[[inttoptr4:.*]] = llvm.inttoptr %[[indexcast3]]
// CHECK: call @xsmm_unary_invoke(%[[c1_i64]], %[[dispatch]], %[[inttoptr]], %[[c0]], %[[inttoptr4]], %[[c0]])

// -----

func.func @identity_2d_bcast_to_3d(%arg0: memref<128x256xf32>, %arg1: memref<512x128x256xf32>) {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x256xf32>, vector<128x256xf32>
%1 = vector.broadcast %0 : vector<128x256xf32> to vector<512x128x256xf32>
vector.transfer_write %1, %arg1[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<512x128x256xf32>, memref<512x128x256xf32>
return
}

// CHECK-LABEL: func.func @identity_2d_bcast_to_3d(
// CHECK: %[[arg0:.*]]: memref<128x256xf32>, %[[arg1:.*]]: memref<512x128x256xf32>) {
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64
// CHECK-DAG: %[[c65536_i64:.*]] = arith.constant 65536 : i64
// CHECK-DAG: %[[c256_i64:.*]] = arith.constant 256 : i64
// CHECK-DAG: %[[c32768_i64:.*]] = arith.constant 32768 : i64
// CHECK-DAG: %[[c128_i64:.*]] = arith.constant 128 : i64
// CHECK-DAG: %[[c4_i64:.*]] = arith.constant 4 : i64
// CHECK: %[[dispatch:.*]] = call @xsmm_unary_dispatch(%[[c1_i64]], %[[c1_i64]], %[[c65536_i64]], %[[c256_i64]], %[[c32768_i64]], %[[c128_i64]], %[[c4_i64]])
// CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[arg0]]
// CHECK: %[[indexcast:.*]] = arith.index_cast %[[intptr]]
// CHECK: %[[inttoptr:.*]] = llvm.inttoptr %[[indexcast]]
// CHECK: %[[intptr_0:.*]] = memref.extract_aligned_pointer_as_index %[[arg1]]
// CHECK: %[[indexcast3:.*]] = arith.index_cast %[[intptr_0]]
// CHECK: %[[inttoptr4:.*]] = llvm.inttoptr %[[indexcast3]]
// CHECK: call @xsmm_unary_invoke(%[[c1_i64]], %[[dispatch]], %[[inttoptr]], %[[c0]], %[[inttoptr4]], %[[c0]])

// -----

// Vector dialect's Verifiers throw an error, can't expect-error this or handle this in our validation
XFAIL:*
func.func @identity_size_mismatch(%arg0: memref<256xf32>, %arg1: memref<128x512xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = vector.transfer_read %arg0[%c0], %cst {in_bounds = [true]} : memref<256xf32>, vector<256xf32>
%1 = vector.broadcast %0 : vector<256xf32> to vector<128x512xf32>
vector.transfer_write %1, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<128x512xf32>, memref<128x512xf32>
return
}

// ----

// Vector dialect's Verifiers throw an error, can't expect-error this or handle this in our validation
XFAIL:*

func.func @identity_rank_reduction(%arg0: memref<256x512f32>, %arg1: memref<512xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<256x512xf32>, vector<256x512xf32>
%1 = vector.broadcast %0 : vector<256x512xf32> to vector<512xf32>
vector.transfer_write %1, %arg1[%c0] {in_bounds = [true]} : vector<512xf32>, memref<512xf32>
return
}

0 comments on commit 07502f2

Please sign in to comment.