diff --git a/lib/TPP/Transforms/Vectorization.cpp b/lib/TPP/Transforms/Vectorization.cpp
index 38406f0f9..0073da373 100644
--- a/lib/TPP/Transforms/Vectorization.cpp
+++ b/lib/TPP/Transforms/Vectorization.cpp
@@ -42,7 +42,7 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
     if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) ==
             xsmm::DataTypeAttr::get(rewriter.getContext(),
                                     xsmm::DataType::BF16) &&
-        linalgOp.getIteratorTypes().size() >= 5 &&
+        linalgOp.getIteratorTypes().size() >= 4 &&
         linalgOp.getNumOperands() == 3) {
       SmallVector<int64_t> shape;
       SmallVector<int64_t> indices;
@@ -72,7 +72,8 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
       }
       auto map0 = linalgOp.getIndexingMapsArray()[0];
       auto map1 = linalgOp.getIndexingMapsArray()[1];
-      map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1), 3);
+      map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1),
+                               map0.getNumResults());
       int map1Index = map1.getNumResults() - 3;
       AffineExpr expr = map1.getResult(map1Index);
       if (isa<AffineBinaryOpExpr>(expr)) {
diff --git a/test/Passes/pass-vectorization.mlir b/test/Passes/pass-vectorization.mlir
index c99fc6d32..257a21c8f 100644
--- a/test/Passes/pass-vectorization.mlir
+++ b/test/Passes/pass-vectorization.mlir
@@ -110,3 +110,75 @@ module {
 // CHECK-NOT: %[[vec3:.*]] = vector.transfer_read
 // CHECK-NOT: %[[vec4:.*]] = vector.contract
 // CHECK-NOT: vector.transfer_write %[[vec4]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
+
+func.func @vnni_brgemm_strided(%arg0: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>,
+                               %arg1: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>,
+                               %arg2: memref<8x8xbf16>) {
+  linalg.generic {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]}
+    ins(%arg0, %arg1 : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>)
+    outs(%arg2 : memref<8x8xbf16>) {
+      ^bb0(%in: bf16, %in_9: bf16, %out: bf16):
+        %11 = arith.mulf %in, %in_9 : bf16
+        %12 = arith.addf %out, %11 : bf16
+        linalg.yield %12 : bf16
+  }
+  return
+}
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
+//
+// CHECK-LABEL: func.func @vnni_brgemm_strided(
+// CHECK: %[[arg0:.*]]: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, %[[arg1:.*]]: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>, %[[arg2:.*]]: memref<8x8xbf16>) {
+// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : bf16
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[expand_shape:.*]] = memref.expand_shape %[[arg0]] {{\[}}[0], [1], [2, 3]] output_shape [8, 8, 4, 2] : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>> into memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>>
+// CHECK: %[[read0:.*]] = vector.transfer_read %[[expand_shape]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true, true]}
+// CHECK: %[[read1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true, true]}
+// CHECK: %[[read2:.*]] = vector.transfer_read %[[arg2]][%[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true]}
+// CHECK: %[[read3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} %[[read0]], %[[read1]], %[[read2]] : vector<8x8x4x2xbf16>, vector<8x4x8x2xbf16> into vector<8x8xbf16>
+// CHECK: vector.transfer_write %[[read3]], %[[arg2]][%[[c0]], %[[c0]]] {in_bounds = [true, true]}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3 floordiv 2, d2, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @non_square_vnni_gemm(%arg0: memref<64x16xbf16>,
+    %arg1: memref<8x64x2xbf16>, %arg2: memref<64x64xbf16>) {
+  linalg.generic {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+    ins(%arg0, %arg1 : memref<64x16xbf16>, memref<8x64x2xbf16>)
+    outs(%arg2 : memref<64x64xbf16>) {
+      ^bb0(%in: bf16, %in_2: bf16, %out: bf16):
+        %1 = arith.mulf %in, %in_2 : bf16
+        %2 = arith.addf %out, %1 : bf16
+        linalg.yield %2 : bf16
+  }
+  return
+}
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+//
+// CHECK-LABEL: func.func @non_square_vnni_gemm(
+// CHECK: %[[arg0:.*]]: memref<64x16xbf16>, %[[arg1:.*]]: memref<8x64x2xbf16>, %[[arg2:.*]]: memref<64x64xbf16>) {
+// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : bf16
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[expand_shape:.*]] = memref.expand_shape %[[arg0]] {{\[}}[0], [1, 2]] output_shape [64, 8, 2]
+// CHECK: %[[read0:.*]] = vector.transfer_read %[[expand_shape]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true]}
+// CHECK: %[[read1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true]}
+// CHECK: %[[read2:.*]] = vector.transfer_read %[[arg2]][%[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true]}
+// CHECK: %[[read3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[read0]], %[[read1]], %[[read2]]
+// CHECK: vector.transfer_write %[[read3]], %[[arg2]][%[[c0]], %[[c0]]] {in_bounds = [true, true]} : vector<64x64xbf16>, memref<64x64xbf16>