diff --git a/test/Conversion/VectorToXsmm/vector-to-identity.mlir b/test/Conversion/VectorToXsmm/vector-to-identity.mlir index 5c56dd853..09f569d24 100644 --- a/test/Conversion/VectorToXsmm/vector-to-identity.mlir +++ b/test/Conversion/VectorToXsmm/vector-to-identity.mlir @@ -80,3 +80,59 @@ func.func @identity_subview_copy(%arg0: memref<128x1xf32>, %arg1: memref<512x128 // CHECK: %[[inttoptr4:.*]] = llvm.inttoptr %[[indexcast3]] // CHECK: call @xsmm_unary_invoke(%[[c1_i64]], %[[dispatch]], %[[inttoptr]], %[[c0]], %[[inttoptr4]], %[[c0]]) +// ----- + +func.func @identity_2d_bcast_to_3d(%arg0: memref<128x256xf32>, %arg1: memref<512x128x256xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x256xf32>, vector<128x256xf32> + %1 = vector.broadcast %0 : vector<128x256xf32> to vector<512x128x256xf32> + vector.transfer_write %1, %arg1[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<512x128x256xf32>, memref<512x128x256xf32> + return +} + +// CHECK-LABEL: func.func @identity_2d_bcast_to_3d( +// CHECK: %[[arg0:.*]]: memref<128x256xf32>, %[[arg1:.*]]: memref<512x128x256xf32>) { +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1_i64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[c65536_i64:.*]] = arith.constant 65536 : i64 +// CHECK-DAG: %[[c256_i64:.*]] = arith.constant 256 : i64 +// CHECK-DAG: %[[c32768_i64:.*]] = arith.constant 32768 : i64 +// CHECK-DAG: %[[c128_i64:.*]] = arith.constant 128 : i64 +// CHECK-DAG: %[[c4_i64:.*]] = arith.constant 4 : i64 +// CHECK: %[[dispatch:.*]] = call @xsmm_unary_dispatch(%[[c1_i64]], %[[c1_i64]], %[[c65536_i64]], %[[c256_i64]], %[[c32768_i64]], %[[c128_i64]], %[[c4_i64]]) +// CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[arg0]] +// CHECK: %[[indexcast:.*]] = arith.index_cast %[[intptr]] +// CHECK: %[[inttoptr:.*]] = llvm.inttoptr %[[indexcast]] +// CHECK: %[[intptr_0:.*]] = memref.extract_aligned_pointer_as_index %[[arg1]] +// CHECK: %[[indexcast3:.*]] = arith.index_cast %[[intptr_0]] +// CHECK: %[[inttoptr4:.*]] = llvm.inttoptr %[[indexcast3]] +// CHECK: call @xsmm_unary_invoke(%[[c1_i64]], %[[dispatch]], %[[inttoptr]], %[[c0]], %[[inttoptr4]], %[[c0]]) + +// ----- + +// Vector dialect's Verifiers throw an error, can't expect-error this or handle this in our validation +XFAIL:* +func.func @identity_size_mismatch(%arg0: memref<256xf32>, %arg1: memref<128x512xf32>) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.transfer_read %arg0[%c0], %cst {in_bounds = [true]} : memref<256xf32>, vector<256xf32> + %1 = vector.broadcast %0 : vector<256xf32> to vector<128x512xf32> + vector.transfer_write %1, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<128x512xf32>, memref<128x512xf32> + return +} + +// ---- + +// Vector dialect's Verifiers throw an error, can't expect-error this or handle this in our validation +XFAIL:* + +func.func @identity_rank_reduction(%arg0: memref<256x512f32>, %arg1: memref<512xf32>) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<256x512xf32>, vector<256x512xf32> + %1 = vector.broadcast %0 : vector<256x512xf32> to vector<512xf32> + vector.transfer_write %1, %arg1[%c0] {in_bounds = [true]} : vector<512xf32>, memref<512xf32> + return +} +