Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Nirvedh Meshram <[email protected]>
  • Loading branch information
nirvedhmeshram committed Jan 20, 2025
1 parent 737acd6 commit a291ccd
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
// CHECK: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
// CHECK: %[[CST:.+]] = arith.constant 0xFFC00000 : f32
// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
// CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
Expand Down Expand Up @@ -49,7 +49,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK-NO-FUSE: func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK-NO-FUSE: %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK-NO-FUSE: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
// CHECK-NO-FUSE: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
// CHECK-NO-FUSE: %[[CST:.+]] = arith.constant 0xFFC00000 : f32
// CHECK-NO-FUSE: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
// CHECK-NO-FUSE: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
// CHECK-NO-FUSE-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) {
// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
// CHECK: arith.maxnumf
// CHECK: arith.maxnumf
// CHECK: vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: scf.for {{.*}} -> (vector<4xf32>) {
// CHECK: vector.transfer_read
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,6 @@ func.func @softmax() attributes {hal.executable.target = #executable_target_vulk
// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
// CHECK: arith.maxnumf
// CHECK: arith.maxnumf
// CHECK: vector.splat %{{.*}} : vector<4xf32>
// CHECK: scf.for {{.*}} -> (vector<4xf32>) {
// CHECK: vector.transfer_read
Expand Down
17 changes: 6 additions & 11 deletions compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,18 @@ hal.executable @i4_dequant_unit_matmul_f16 {

// CHECK-DAG: %[[CSTVEC4XI32_255:.+]] = spirv.Constant dense<255> : vector<4xi32>
// CHECK-DAG: %[[CSTVEC4XI32_0:.+]] = spirv.Constant dense<0> : vector<4xi32>
// CHECK-DAG: %[[CSTVEC4XI32_0_4:.+]] = spirv.Constant dense<[0, 4, 0, 4]> : vector<4xi32>
// CHECK-DAG: %[[CSTVEC4XI32_15__16:.+]] = spirv.Constant dense<[15, -16, 15, -16]> : vector<4xi32>
// CHECK-DAG: %[[CSTVEC2XI32_4:.+]] = spirv.Constant dense<4> : vector<2xi32>
// CHECK-DAG: %[[CSTVEC2XI32_15:.+]] = spirv.Constant dense<15> : vector<2xi32>

// CHECK: spirv.mlir.loop

// Load the quantized weight and get 8xi4 out of it.
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<4xi32>
// CHECK: %[[SHUF01:.+]] = spirv.VectorShuffle [0 : i32, 1 : i32] %[[LOAD]], %[[LOAD]] : vector<4xi32>, vector<4xi32> -> vector<2xi32>
// CHECK: %[[SHUF0011:.+]] = spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32] %[[SHUF01]], %[[SHUF01]] : vector<2xi32>, vector<2xi32> -> vector<4xi32>
// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[SHUF0011]], %[[CSTVEC4XI32_15__16]] : vector<4xi32>
// CHECK: %[[SHIFTED:.+]] = spirv.ShiftRightLogical %[[MASKED]], %[[CSTVEC4XI32_0_4]] : vector<4xi32>, vector<4xi32>
// CHECK: %[[LOW4HIGH4_ZEROUPPER:.+]] = spirv.BitwiseAnd %[[SHIFTED]], %[[CSTVEC4XI32_255]] : vector<4xi32>
// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[SHUF01]], %[[CSTVEC2XI32_15]] : vector<2xi32>
// CHECK: %[[SHIFTED:.+]] = spirv.ShiftRightLogical %[[SHUF01]], %[[CSTVEC2XI32_4]] : vector<2xi32>, vector<2xi32>
// CHECK: %[[SHUF0011:.+]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[MASKED]], %[[SHIFTED]] : vector<2xi32>, vector<2xi32> -> vector<4xi32>
// CHECK: %[[LOW4HIGH4_ZEROUPPER:.+]] = spirv.BitwiseAnd %[[SHUF0011]], %[[CSTVEC4XI32_255]] : vector<4xi32>

// CHECK: %[[SHUF23:.+]] = spirv.VectorShuffle [2 : i32, 3 : i32] %[[LOAD:.+]], %[[LOAD:.+]] : vector<4xi32>, vector<4xi32> -> vector<2xi32>

Expand Down Expand Up @@ -186,8 +186,6 @@ hal.executable @i4_dequant_matvec_f16_subgroup_64 {
// CHECK-DAG: %[[C0:.+]] = spirv.Constant 0 : i32
// CHECK-DAG: %[[CSTVEC4XF16_1:.+]] = spirv.Constant dense<1.000000e+00> : vector<4xf16>
// CHECK-DAG: %[[CSTVEC4XI32_255:.+]] = spirv.Constant dense<255> : vector<4xi32>
// CHECK-DAG: %[[CSTVEC2XI32_1:.+]] = spirv.Constant dense<[0, 4, 0, 4]> : vector<4xi32>
// CHECK-DAG: %[[CSTVEC2XI32_2:.+]] = spirv.Constant dense<[15, -16, 15, -16]> : vector<4xi32>

// CHECK: %[[WIDX:.+]] = spirv.CompositeExtract %{{.*}}[0 : i32] : vector<3xi32>
// CHECK: %[[PCPTR:.+]] = spirv.AccessChain %{{.*}}[{{.*}}, %[[C0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>, i32, i32
Expand All @@ -209,9 +207,6 @@ hal.executable @i4_dequant_matvec_f16_subgroup_64 {
// CHECK: %[[ACCESS:.+]] = spirv.AccessChain %[[RADDR]][{{.*}}, %[[OFFSET]]] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: spirv.Load "StorageBuffer" %[[ACCESS]] : i32

// CHECK: spirv.ShiftRightLogical %{{.*}}, %[[CSTVEC2XI32_1]] : vector<4xi32>, vector<4xi32>
// CHECK: spirv.BitwiseAnd %{{.*}}, %[[CSTVEC4XI32_255]] : vector<4xi32>

// CHECK: spirv.ConvertUToF %{{.+}} : vector<4xi32> to vector<4xf16>
// CHECK: spirv.FSub %{{.+}}, %{{.+}} : vector<4xf16>
// CHECK-COUNT-2: spirv.FMul %{{.+}}, %{{.+}} : vector<4xf16>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,16 @@ hal.executable @i4_dequant {
// CHECK-LABEL: spirv.func @i4_dequant()

// CHECK: %[[BYTE1:.+]] = spirv.VectorShuffle [0 : i32, 1 : i32] {{.*}} : vector<4xi32>, vector<4xi32> -> vector<2xi32>
// CHECK: %[[COPIED:.+]] = spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32] %[[BYTE1]], %[[BYTE1]] : vector<2xi32>, vector<2xi32> -> vector<4xi32>
// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[COPIED]]
// CHECK: %[[SHIFTED:.+]] = spirv.ShiftRightLogical %[[MASKED]]
// CHECK: %[[ZEROUPPER:.+]] = spirv.BitwiseAnd %[[SHIFTED]]
// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[BYTE1]]
// CHECK: %[[SHIFTED:.+]] = spirv.ShiftRightLogical %[[BYTE1]]
// CHECK: %[[COPIED:.+]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[MASKED]], %[[SHIFTED]] : vector<2xi32>, vector<2xi32> -> vector<4xi32>
// CHECK: %[[MASKED2:.+]] = spirv.BitwiseAnd %[[COPIED]]
// CHECK: spirv.VectorShuffle [2 : i32, 3 : i32] {{.*}} : vector<4xi32>, vector<4xi32> -> vector<2xi32>
// CHECK-COUNT-3: spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32]
// CHECK: spirv.VectorShuffle [0 : i32, 1 : i32]
// CHECK: spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32]
// CHECK: spirv.VectorShuffle [2 : i32, 3 : i32]
// CHECK: spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32]
// CHECK-NOT: spirv.VectorShuffle

// CHECK-COUNT-4: spirv.ConvertUToF {{.+}} : vector<4xi32> to vector<4xf32>
// CHECK-COUNT-4: spirv.FSub
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2578,13 +2578,13 @@ module attributes { transform.with_named_sequence } {
transform.yield
}
}
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
// CHECK: func @custom_op_index_handling(%[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xindex>,
// CHECK: scf.forall (%[[IV:[a-zA-Z0-9]+]],
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: iree_linalg_ext.custom_op
// CHECK-SAME: ins(%[[SLICE]]
// CHECK: %[[NEW_INDEX:.+]] = iree_linalg_ext.index 0 : index
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]](%[[NEW_INDEX]], %[[IV]])
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]](%[[IV]])[%[[NEW_INDEX]]]
// CHECK: linalg.generic
// CHECK-SAME: ins(%{{.+}}, %[[INDEX]] :

0 comments on commit a291ccd

Please sign in to comment.