Commit 8441ee0: VNNI gemm fix

KavithaTipturMadhu committed Dec 11, 2024
1 parent d7fcc28 commit 8441ee0
Showing 2 changed files with 75 additions and 2 deletions.
5 changes: 3 additions & 2 deletions lib/TPP/Transforms/Vectorization.cpp
@@ -42,7 +42,7 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
     if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) ==
             xsmm::DataTypeAttr::get(rewriter.getContext(),
                                     xsmm::DataType::BF16) &&
-        linalgOp.getIteratorTypes().size() >= 5 &&
+        linalgOp.getIteratorTypes().size() >= 4 &&
         linalgOp.getNumOperands() == 3) {
       SmallVector<int64_t> shape;
       SmallVector<ReassociationIndices> indices;
@@ -72,7 +72,8 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
       }
       auto map0 = linalgOp.getIndexingMapsArray()[0];
       auto map1 = linalgOp.getIndexingMapsArray()[1];
-      map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1), 3);
+      map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1),
+                               map0.getNumResults());
       int map1Index = map1.getNumResults() - 3;
       AffineExpr expr = map1.getResult(map1Index);
       if (isa<AffineBinaryOpExpr>(expr)) {
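Taken together, the two hunks extend this pattern from batch-reduce VNNI GEMMs to plain VNNI GEMMs: a non-batched VNNI GEMM has only four loops (M, N, the outer K, and the VNNI factor), so the guard drops from five iterators to four, and the VNNI dimension must be appended at the end of the A-operand map rather than at a hardcoded position 3, which is out of range once map0 has only two results. A minimal sketch of the appended-result logic, assuming the MLIR C++ API; the helper name is illustrative, not from the patch:

    #include "mlir/IR/AffineMap.h"

    // Illustrative helper: append the innermost (VNNI) dimension of the
    // B-operand map as the last result of the A-operand map. For a BRGEMM,
    // map0 has 3 results, so appending coincides with the old position 3;
    // for a plain GEMM, map0 has 2 results and the hardcoded 3 was invalid.
    static mlir::AffineMap appendVnniDim(mlir::AffineMap map0,
                                         mlir::AffineMap map1) {
      mlir::AffineExpr vnniDim = map1.getResult(map1.getNumResults() - 1);
      return map0.insertResult(vnniDim, map0.getNumResults());
    }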
72 changes: 72 additions & 0 deletions test/Passes/pass-vectorization.mlir
@@ -110,3 +110,75 @@ module {
// CHECK-NOT: %[[vec3:.*]] = vector.transfer_read
// CHECK-NOT: %[[vec4:.*]] = vector.contract
// CHECK-NOT: vector.transfer_write %[[vec4]]

// -----

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>

func.func @vnni_brgemm_strided(%arg0: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>,
%arg1: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>,
%arg2: memref<8x8xbf16>) {
linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]}
ins(%arg0, %arg1 : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>)
outs(%arg2 : memref<8x8xbf16>) {
^bb0(%in: bf16, %in_9: bf16, %out: bf16):
%11 = arith.mulf %in, %in_9 : bf16
%12 = arith.addf %out, %11 : bf16
linalg.yield %12 : bf16
}
return
}
// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
//
// CHECK-LABEL: func.func @vnni_brgemm_strided(
// CHECK: %[[arg0:.*]]: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, %[[arg1:.*]]: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>, %[[arg2:.*]]: memref<8x8xbf16>) {
// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : bf16
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[expand_shape:.*]] = memref.expand_shape %[[arg0]] {{\[}}[0], [1], [2, 3]] output_shape [8, 8, 4, 2] : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>> into memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>>
// CHECK: %[[read0:.*]] = vector.transfer_read %[[expand_shape]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true, true]}
// CHECK: %[[read1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true, true]}
// CHECK: %[[read2:.*]] = vector.transfer_read %[[arg2]][%[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true]}
// CHECK: %[[read3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} %[[read0]], %[[read1]], %[[read2]] : vector<8x8x4x2xbf16>, vector<8x4x8x2xbf16> into vector<8x8xbf16>
// CHECK: vector.transfer_write %[[read3]], %[[arg2]][%[[c0]], %[[c0]]] {in_bounds = [true, true]}
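The strided BRGEMM test checks that the A operand's innermost K dimension (size 8) is split by memref.expand_shape into (4, 2), which is what lets the `d3 floordiv 2` in the B map become a plain dimension in the rewritten #map1. A small standalone check of why the split addresses the same elements, using the strides from this test; the helper names are hypothetical:

    #include <cassert>
    #include <cstdint>

    // Offsets into the A operand before and after the K -> (K/2, 2) split:
    // strided<[64, 8, 1]> becomes strided<[64, 8, 2, 1]>.
    static int64_t offsetOriginal(int64_t b, int64_t m, int64_t k) {
      return b * 64 + m * 8 + k;
    }
    static int64_t offsetExpanded(int64_t b, int64_t m, int64_t ko, int64_t ki) {
      return b * 64 + m * 8 + ko * 2 + ki;
    }

    int main() {
      // Splitting k into (k / 2, k % 2) reaches the same element, which is
      // why the expand_shape eliminates the floordiv in the indexing map.
      for (int64_t b = 0; b < 8; ++b)
        for (int64_t m = 0; m < 8; ++m)
          for (int64_t k = 0; k < 8; ++k)
            assert(offsetOriginal(b, m, k) ==
                   offsetExpanded(b, m, k / 2, k % 2));
    }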

// -----

#map = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d3 floordiv 2, d2, d0)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>

func.func @non_square_vnni_gemm(%arg0: memref<64x16xbf16>,
%arg1: memref<8x64x2xbf16>, %arg2: memref<64x64xbf16>) {
linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
ins(%arg0, %arg1 : memref<64x16xbf16>, memref<8x64x2xbf16>)
outs(%arg2 : memref<64x64xbf16>) {
^bb0(%in: bf16, %in_2: bf16, %out: bf16):
%1 = arith.mulf %in, %in_2 : bf16
%2 = arith.addf %out, %1 : bf16
linalg.yield %2 : bf16
}
return
}

// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
//
// CHECK-LABEL: func.func @non_square_vnni_gemm(
// CHECK: %[[arg0:.*]]: memref<64x16xbf16>, %[[arg1:.*]]: memref<8x64x2xbf16>, %[[arg2:.*]]: memref<64x64xbf16>) {
// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : bf16
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[expand_shape:.*]] = memref.expand_shape %[[arg0]] {{\[}}[0], [1, 2]] output_shape [64, 8, 2]
// CHECK: %[[read0:.*]] = vector.transfer_read %[[expand_shape]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true]}
// CHECK: %[[read1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true]}
// CHECK: %[[read2:.*]] = vector.transfer_read %[[arg2]][%[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true]}
// CHECK: %[[read3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[read0]], %[[read1]], %[[read2]]
// CHECK: vector.transfer_write %[[read3]], %[[arg2]][%[[c0]], %[[c0]]] {in_bounds = [true, true]} : vector<64x64xbf16>, memref<64x64xbf16>
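This non-square test is the regression case for the hardcoded insert position: with four iterators, map0 = (d0, d1, d2, d3) -> (d1, d3) has only two results, so inserting at position 3 was out of range, while appending at getNumResults() produces the (d1, d3, d0) map CHECKed above. A sketch of that map construction, assuming the MLIR C++ API; purely illustrative, not code from the patch:

    #include "mlir/IR/AffineExpr.h"
    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/MLIRContext.h"

    // Rebuild map0 from this test and append map1's last result (d0),
    // as the fixed pattern does.
    static mlir::AffineMap nonSquareMap0(mlir::MLIRContext *ctx) {
      mlir::AffineExpr d0, d1, d2, d3;
      mlir::bindDims(ctx, d0, d1, d2, d3);
      auto map0 = mlir::AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
                                       {d1, d3}, ctx);
      // Appending at getNumResults() (here 2) yields
      // (d0, d1, d2, d3) -> (d1, d3, d0).
      return map0.insertResult(d0, map0.getNumResults());
    }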
