From 8441ee0c021037bc5762565dc6e0da735a33c3ef Mon Sep 17 00:00:00 2001
From: Kavitha Madhu
Date: Thu, 5 Dec 2024 18:23:49 -0800
Subject: [PATCH 1/4] VNNI gemm fix

---
 lib/TPP/Transforms/Vectorization.cpp |  5 +-
 test/Passes/pass-vectorization.mlir  | 72 ++++++++++++++++++++++++++++
 2 files changed, 75 insertions(+), 2 deletions(-)

diff --git a/lib/TPP/Transforms/Vectorization.cpp b/lib/TPP/Transforms/Vectorization.cpp
index 38406f0f9..0073da373 100644
--- a/lib/TPP/Transforms/Vectorization.cpp
+++ b/lib/TPP/Transforms/Vectorization.cpp
@@ -42,7 +42,7 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
     if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) ==
             xsmm::DataTypeAttr::get(rewriter.getContext(),
                                     xsmm::DataType::BF16) &&
-        linalgOp.getIteratorTypes().size() >= 5 &&
+        linalgOp.getIteratorTypes().size() >= 4 &&
         linalgOp.getNumOperands() == 3) {
       SmallVector<int64_t> shape;
       SmallVector<ReassociationIndices> indices;
@@ -72,7 +72,8 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
       }
       auto map0 = linalgOp.getIndexingMapsArray()[0];
       auto map1 = linalgOp.getIndexingMapsArray()[1];
-      map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1), 3);
+      map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1),
+                               map0.getNumResults());
       int map1Index = map1.getNumResults() - 3;
       AffineExpr expr = map1.getResult(map1Index);
       if (isa<AffineBinaryOpExpr>(expr)) {
diff --git a/test/Passes/pass-vectorization.mlir b/test/Passes/pass-vectorization.mlir
index c99fc6d32..257a21c8f 100644
--- a/test/Passes/pass-vectorization.mlir
+++ b/test/Passes/pass-vectorization.mlir
@@ -110,3 +110,75 @@ module {
 // CHECK-NOT: %[[vec3:.*]] = vector.transfer_read
 // CHECK-NOT: %[[vec4:.*]] = vector.contract
 // CHECK-NOT: vector.transfer_write %[[vec4]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
+
+func.func @vnni_brgemm_strided(%arg0: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>,
+                               %arg1: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>,
+                               %arg2: memref<8x8xbf16>) {
+  linalg.generic {
+      indexing_maps = [#map, #map1, #map2],
+      iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]}
+      ins(%arg0, %arg1 : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>)
+      outs(%arg2 : memref<8x8xbf16>) {
+    ^bb0(%in: bf16, %in_9: bf16, %out: bf16):
+      %11 = arith.mulf %in, %in_9 : bf16
+      %12 = arith.addf %out, %11 : bf16
+      linalg.yield %12 : bf16
+  }
+  return
+}
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
+//
+// CHECK-LABEL: func.func @vnni_brgemm_strided(
+// CHECK: %[[arg0:.*]]: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, %[[arg1:.*]]: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>, %[[arg2:.*]]: memref<8x8xbf16>) {
+// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : bf16
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[expand_shape:.*]] = memref.expand_shape %[[arg0]] {{\[}}[0], [1], [2, 3]] output_shape [8, 8, 4, 2] : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>> into memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>>
+// CHECK: %[[read0:.*]] = vector.transfer_read %[[expand_shape]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true, true]}
+// CHECK: %[[read1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true, true]}
+// CHECK: %[[read2:.*]] = vector.transfer_read %[[arg2]][%[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true]}
+// CHECK: %[[read3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} %[[read0]], %[[read1]], %[[read2]] : vector<8x8x4x2xbf16>, vector<8x4x8x2xbf16> into vector<8x8xbf16>
+// CHECK: vector.transfer_write %[[read3]], %[[arg2]][%[[c0]], %[[c0]]] {in_bounds = [true, true]}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3 floordiv 2, d2, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @non_square_vnni_gemm(%arg0: memref<64x16xbf16>,
+                                %arg1: memref<8x64x2xbf16>, %arg2: memref<64x64xbf16>) {
+  linalg.generic {
+      indexing_maps = [#map, #map1, #map2],
+      iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+      ins(%arg0, %arg1 : memref<64x16xbf16>, memref<8x64x2xbf16>)
+      outs(%arg2 : memref<64x64xbf16>) {
+    ^bb0(%in: bf16, %in_2: bf16, %out: bf16):
+      %1 = arith.mulf %in, %in_2 : bf16
+      %2 = arith.addf %out, %1 : bf16
+      linalg.yield %2 : bf16
+  }
+  return
+}
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+//
+// CHECK-LABEL: func.func @non_square_vnni_gemm(
+// CHECK: %[[arg0:.*]]: memref<64x16xbf16>, %[[arg1:.*]]: memref<8x64x2xbf16>, %[[arg2:.*]]: memref<64x64xbf16>) {
+// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : bf16
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[expand_shape:.*]] = memref.expand_shape %[[arg0]] {{\[}}[0], [1, 2]] output_shape [64, 8, 2]
+// CHECK: %[[read0:.*]] = vector.transfer_read %[[expand_shape]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true]}
+// CHECK: %[[read1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true]}
+// CHECK: %[[read2:.*]] = vector.transfer_read %[[arg2]][%[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true]}
+// CHECK: %[[read3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[read0]], %[[read1]], %[[read2]]
+// CHECK: vector.transfer_write %[[read3]], %[[arg2]][%[[c0]], %[[c0]]] {in_bounds = [true, true]} : vector<64x64xbf16>, memref<64x64xbf16>

From c614f3ff18dc6ec253422fc857308a9b6a685586 Mon Sep 17 00:00:00 2001
From: Kavitha Madhu
Date: Wed, 11 Dec 2024 02:09:00 -0800
Subject: [PATCH 2/4] Validation updated for random order of iterators
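
Replace the ad-hoc operand and iterator checks in LinalgGenericToVector
with a structural helper, xsmm::utils::checkVNNIGemmStructure, built on
a local inferContractionDims. The contraction dimensions are now derived
from the indexing maps themselves, so the m/n/k iterators of a VNNI GEMM
may appear in any order; the only positional requirement left is that
the innermost iterator is the VNNI reduction. For example, the packed
map from the tests added in the previous patch,

  #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)>

is recognized through the `floordiv 2` that marks the VNNI-split
reduction dimension, rather than through a fixed result position.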
---
 include/TPP/Dialect/Xsmm/XsmmUtils.h |  17 ++
 lib/TPP/Dialect/Xsmm/XsmmUtils.cpp   | 144 ++++++++++++++++++++++++++-
 lib/TPP/Transforms/Vectorization.cpp |  15 +--
 3 files changed, 167 insertions(+), 9 deletions(-)

diff --git a/include/TPP/Dialect/Xsmm/XsmmUtils.h b/include/TPP/Dialect/Xsmm/XsmmUtils.h
index 65784ba4b..ddef9356d 100644
--- a/include/TPP/Dialect/Xsmm/XsmmUtils.h
+++ b/include/TPP/Dialect/Xsmm/XsmmUtils.h
@@ -26,6 +26,10 @@ class MemRefType;
 namespace func {
 class CallOp;
 }
+namespace linalg {
+class GenericOp;
+struct ContractionDimensions;
+} // namespace linalg
 
 namespace xsmm {
 class UnaryKindAttr;
@@ -122,6 +126,19 @@ func::CallOp buildXsmmCall(RewriterBase &rewriter, XsmmCallType callType,
                            SmallVector<Value> operands, TypeRange results,
                            FlatSymbolRefAttr fnName, Operation *parentOp,
                            Operation *insertBefore);
+
+std::optional<unsigned>
+getPosInCodomain(unsigned dim, linalg::GenericOp linalgOp, AffineMap map);
+
+LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
+                                     linalg::GenericOp linalgOp);
+
+FailureOr<linalg::ContractionDimensions>
+inferContractionDims(linalg::GenericOp genericOp);
+
+std::optional<unsigned> getAffineBinaryOpExprIndex(AffineMap map, int index,
+                                                   MLIRContext *context);
+
 } // namespace utils
 } // namespace xsmm
 } // namespace mlir
diff --git a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
index 7a070f928..70a51d99b 100644
--- a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
+++ b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
@@ -8,21 +8,27 @@
 #include "TPP/Dialect/Xsmm/XsmmUtils.h"
 #include "TPP/Dialect/Xsmm/XsmmOps.h"
+#include "TPP/IR/StructuredOpMatcher.h"
+#include "TPP/Transforms/Utils/BuilderUtils.h"
 #include "TPP/Transforms/Utils/VNNIUtils.h"
 #include "TPP/Transforms/Utils/ValueUtils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/TypeUtilities.h"
-
-#include "TPP/Transforms/Utils/BuilderUtils.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "xsmm-utils"
 
+using namespace mlir;
+using namespace mlir::linalg;
+
 namespace mlir {
 namespace xsmm {
 namespace utils {
@@ -564,6 +570,140 @@ FailureOr<SmallVector<Attribute>> getBrgemmFlags(PatternRewriter &rewriter,
   return attributes;
 }
 
+static llvm::SmallDenseSet<int64_t>
+findIndexingOperand(AffineMap indexingMap,
+                    ArrayRef<mlir::utils::IteratorType> iterators,
+                    mlir::utils::IteratorType iter) {
+  assert(iterators.size() == indexingMap.getNumDims());
+  llvm::SmallDenseSet<int64_t> res;
+  for (AffineExpr e : indexingMap.getResults()) {
+    int position = -1;
+    if (isa<AffineDimExpr>(e)) {
+      auto expr = dyn_cast<AffineDimExpr>(e);
+      position = expr.getPosition();
+    } else if (isa<AffineBinaryOpExpr>(e)) {
+      auto lhs = dyn_cast<AffineBinaryOpExpr>(e).getLHS();
+      assert(isa<AffineDimExpr>(lhs));
+      position = (dyn_cast<AffineDimExpr>(lhs)).getPosition();
+    }
+    assert(position >= 0);
+    if (iterators[position] == iter &&
+        llvm::count_if(indexingMap.getResults(), [position](AffineExpr e) {
+          return e.isFunctionOfDim(position);
+        }) == 1)
+      res.insert(position);
+  }
+  return res;
+}
+
+namespace {
+auto par = mlir::utils::IteratorType::parallel;
+auto red = mlir::utils::IteratorType::reduction;
+} // namespace
+
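+// Infers the contraction dimensions of a generic op from its indexing maps
+// and iterator types. Mirrors linalg::inferContractionDims, but also accepts
+// map results that are affine binary expressions (e.g. `d3 floordiv 2`, as
+// produced by VNNI packing) by keying on the LHS dimension of the expression.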
+FailureOr<ContractionDimensions>
+inferContractionDims(linalg::GenericOp linalgOp) {
+  auto indexingMaps = linalgOp.getIndexingMapsArray();
+  auto iterators = linalgOp.getIteratorTypesArray();
+  llvm::SmallDenseSet<int64_t> a =
+      findIndexingOperand(indexingMaps[0], iterators, par);
+  llvm::SmallDenseSet<int64_t> b =
+      findIndexingOperand(indexingMaps[1], iterators, par);
+  llvm::SmallDenseSet<int64_t> c =
+      findIndexingOperand(indexingMaps[2], iterators, par);
+
+  // A & C - B are the iterators involved in an outer-product along A (the LHS).
+  llvm::SmallDenseSet<int64_t> ac = a;
+  llvm::set_intersect(ac, c);
+  llvm::set_subtract(ac, b);
+  // B & C - A are the iterators involved in an outer-product along B (the RHS).
+  llvm::SmallDenseSet<int64_t> bc = b;
+  llvm::set_intersect(bc, c);
+  llvm::set_subtract(bc, a);
+  // A & B & C are the "batch" dimensions.
+  llvm::SmallDenseSet<int64_t> batches = a;
+  llvm::set_intersect(batches, b);
+  llvm::set_intersect(batches, c);
+
+  // The reduction iterators common to A and B are the reduction dimensions.
+  llvm::SmallDenseSet<int64_t> ra =
+      findIndexingOperand(indexingMaps[0], iterators, red);
+  llvm::SmallDenseSet<int64_t> rb =
+      findIndexingOperand(indexingMaps[1], iterators, red);
+  llvm::set_intersect(ra, rb);
+
+  // Return each set in sorted order.
+  ContractionDimensions dimensions{
+      SmallVector<unsigned, 2>(batches.begin(), batches.end()),
+      SmallVector<unsigned, 2>(ac.begin(), ac.end()),
+      SmallVector<unsigned, 2>(bc.begin(), bc.end()),
+      SmallVector<unsigned, 2>(ra.begin(), ra.end())};
+  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
+  llvm::sort(dimensions.m.begin(), dimensions.m.end());
+  llvm::sort(dimensions.n.begin(), dimensions.n.end());
+  llvm::sort(dimensions.k.begin(), dimensions.k.end());
+  return dimensions;
+}
+
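+// Returns the position of the result of `map` that is an affine binary
+// expression whose LHS is dimension `index`, i.e. the `floordiv 2` result
+// that VNNI packing introduces for the reduction dimension.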
+std::optional<unsigned> getAffineBinaryOpExprIndex(AffineMap map, int index,
+                                                   MLIRContext *context) {
+  for (unsigned i = 0; i < map.getNumResults(); i++) {
+    auto result = map.getResult(i);
+    if (isa<AffineBinaryOpExpr>(result) &&
+        dyn_cast<AffineBinaryOpExpr>(result).getLHS() ==
+            getAffineDimExpr(index, context)) {
+      return i;
+    }
+  }
+  llvm_unreachable("invalid binary op index");
+}
+
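+// Validates that `linalgOp` is structurally a VNNI GEMM: a three-operand
+// BF16 contraction with at least four iterators, whose innermost iterator
+// is a reduction, whose second indexing map carries a floordiv-packed
+// (VNNI) reduction dimension, and whose body is a mul-add chain.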
+LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
+                                     linalg::GenericOp linalgOp) {
+  if (linalgOp->getNumOperands() != 3)
+    return failure();
+
+  if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) !=
+      xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16)) {
+    return failure();
+  }
+  auto iteratorTypes = linalgOp.getIteratorTypesArray();
+  if (iteratorTypes.size() < 4)
+    return failure();
+
+  auto contractionDims = inferContractionDims(linalgOp);
+  if (failed(contractionDims))
+    return failure();
+
+  unsigned m = contractionDims->m.back();
+  unsigned n = contractionDims->n.back();
+
+  if (!linalg::isParallelIterator(iteratorTypes[m]) ||
+      !linalg::isParallelIterator(iteratorTypes[n])) {
+    return failure();
+  }
+
+  if (!linalg::isReductionIterator(iteratorTypes[iteratorTypes.size() - 1])) {
+    return failure();
+  }
+
+  auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
+  auto map1 = linalgOp.getIndexingMapsArray()[1];
+  auto index = getAffineBinaryOpExprIndex(map1, k, linalgOp.getContext());
+  if (!index)
+    return failure();
+
+  // clang-format off
+  using namespace mlir::structured_match;
+  auto hasRightOpChain =
+      StructuredOpMatcher::make<linalg::GenericOp>()
+          .region(MatchOne(0), WithOpChain<arith::MulFOp, arith::AddFOp>(
+                                   /*captures=*/nullptr));
+  // clang-format on
+  if (!hasRightOpChain.match(linalgOp))
+    return failure();
+  return success();
+}
+
 template FailureOr<SmallVector<Attribute>>
 getBrgemmFlags(PatternRewriter &rewriter, xsmm::BrgemmDispatchOp dispatchOpTy,
diff --git a/lib/TPP/Transforms/Vectorization.cpp b/lib/TPP/Transforms/Vectorization.cpp
index 0073da373..857d85826 100644
--- a/lib/TPP/Transforms/Vectorization.cpp
+++ b/lib/TPP/Transforms/Vectorization.cpp
@@ -18,7 +18,6 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
 namespace mlir {
 namespace tpp {
 #define GEN_PASS_DEF_VECTORIZATIONPASS
@@ -39,11 +38,8 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
                                 PatternRewriter &rewriter) const override {
     if (!linalgOp.hasPureBufferSemantics())
       return failure();
-    if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) ==
-            xsmm::DataTypeAttr::get(rewriter.getContext(),
-                                    xsmm::DataType::BF16) &&
-        linalgOp.getIteratorTypes().size() >= 4 &&
-        linalgOp.getNumOperands() == 3) {
+    auto check = xsmm::utils::checkVNNIGemmStructure(rewriter, linalgOp);
+    if (succeeded(check)) {
       SmallVector<int64_t> shape;
       SmallVector<ReassociationIndices> indices;
       int index = 0;
@@ -74,7 +70,12 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
       auto map1 = linalgOp.getIndexingMapsArray()[1];
       map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1),
                                map0.getNumResults());
-      int map1Index = map1.getNumResults() - 3;
+      auto contractionDims = xsmm::utils::inferContractionDims(linalgOp);
+
+      auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
+      auto map1Index = *xsmm::utils::getAffineBinaryOpExprIndex(
+          map1, k, linalgOp.getContext());
+
       AffineExpr expr = map1.getResult(map1Index);
       if (isa<AffineBinaryOpExpr>(expr)) {

From a2bb3064ec342f13e67b4521520cca7497ea565e Mon Sep 17 00:00:00 2001
From: Kavitha Madhu
Date: Wed, 11 Dec 2024 02:54:36 -0800
Subject: [PATCH 3/4] Added comments

---
 lib/TPP/Dialect/Xsmm/XsmmUtils.cpp   | 4 ++++
 lib/TPP/Transforms/Vectorization.cpp | 4 ++++
 2 files changed, 8 insertions(+)

diff --git a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
index 70a51d99b..039d70d7f 100644
--- a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
+++ b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
@@ -677,21 +677,25 @@ LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
   unsigned m = contractionDims->m.back();
   unsigned n = contractionDims->n.back();
 
+  // m and n dimensions must be parallel dimensions.
   if (!linalg::isParallelIterator(iteratorTypes[m]) ||
       !linalg::isParallelIterator(iteratorTypes[n])) {
     return failure();
   }
 
+  // The innermost dimension must be a reduction dimension for VNNI operations.
   if (!linalg::isReductionIterator(iteratorTypes[iteratorTypes.size() - 1])) {
     return failure();
   }
 
+  // Get the index of the iterator corresponding to the floordiv operation.
   auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
   auto map1 = linalgOp.getIndexingMapsArray()[1];
   auto index = getAffineBinaryOpExprIndex(map1, k, linalgOp.getContext());
   if (!index)
     return failure();
 
+  // Ensure that the body of the generic operation is a mul-add chain.
   // clang-format off
   using namespace mlir::structured_match;
   auto hasRightOpChain =
diff --git a/lib/TPP/Transforms/Vectorization.cpp b/lib/TPP/Transforms/Vectorization.cpp
index 857d85826..9f06fc6a1 100644
--- a/lib/TPP/Transforms/Vectorization.cpp
+++ b/lib/TPP/Transforms/Vectorization.cpp
@@ -68,11 +68,13 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
       }
       auto map0 = linalgOp.getIndexingMapsArray()[0];
       auto map1 = linalgOp.getIndexingMapsArray()[1];
+      // Set the innermost dimension of the first map to the VNNI dimension.
       map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1),
                                map0.getNumResults());
       auto contractionDims = xsmm::utils::inferContractionDims(linalgOp);
 
       auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
+      // Get the index of the iterator corresponding to the floordiv operation.
       auto map1Index = *xsmm::utils::getAffineBinaryOpExprIndex(
           map1, k, linalgOp.getContext());
 
@@ -82,6 +84,8 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
       auto expand = rewriter.create<memref::ExpandShapeOp>(
           linalgOp.getLoc(), shape, linalgOp.getOperand(0), indices);
       linalgOp.setOperand(0, expand.getResult());
+      // Replace the floordiv operation with just the LHS of the floordiv
+      // expression.
       map1 = map1.insertResult(
           dyn_cast<AffineBinaryOpExpr>(map1.getResult(map1Index)).getLHS(),
           map1Index + 1);

From 7d4bea76440f380aeb7a24176b94f64c29c2b1c3 Mon Sep 17 00:00:00 2001
From: Kavitha Madhu
Date: Wed, 11 Dec 2024 03:55:53 -0800
Subject: [PATCH 4/4] Review comments
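
Address review comments: drop the m/n parallel-iterator checks, which
are redundant because inferContractionDims collects only parallel
iterators into m and n, drop the `k.size() > 0` guard on the reduction
dimensions, and rename hasRightOpChain to the more descriptive
isMatmulChain.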
---
 lib/TPP/Dialect/Xsmm/XsmmUtils.cpp | 15 +++------------
 1 file changed, 3 insertions(+), 12 deletions(-)

diff --git a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
index 039d70d7f..62c01a9cc 100644
--- a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
+++ b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
@@ -674,22 +674,13 @@ LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
   if (failed(contractionDims))
     return failure();
 
-  unsigned m = contractionDims->m.back();
-  unsigned n = contractionDims->n.back();
-
-  // m and n dimensions must be parallel dimensions.
-  if (!linalg::isParallelIterator(iteratorTypes[m]) ||
-      !linalg::isParallelIterator(iteratorTypes[n])) {
-    return failure();
-  }
-
   // The innermost dimension must be a reduction dimension for VNNI operations.
   if (!linalg::isReductionIterator(iteratorTypes[iteratorTypes.size() - 1])) {
     return failure();
   }
 
   // Get the index of the iterator corresponding to the floordiv operation.
-  auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
+  auto k = contractionDims->k.back();
   auto map1 = linalgOp.getIndexingMapsArray()[1];
   auto index = getAffineBinaryOpExprIndex(map1, k, linalgOp.getContext());
   if (!index)
     return failure();
 
   // Ensure that the body of the generic operation is a mul-add chain.
   // clang-format off
   using namespace mlir::structured_match;
-  auto hasRightOpChain =
+  auto isMatmulChain =
       StructuredOpMatcher::make<linalg::GenericOp>()
          .region(MatchOne(0), WithOpChain<arith::MulFOp, arith::AddFOp>(
                                   /*captures=*/nullptr));
   // clang-format on
-  if (!hasRightOpChain.match(linalgOp))
+  if (!isMatmulChain.match(linalgOp))
     return failure();
   return success();
 }