VNNI gemm vectorization fix #994

Closed
17 changes: 17 additions & 0 deletions include/TPP/Dialect/Xsmm/XsmmUtils.h
@@ -26,6 +26,10 @@ class MemRefType;
namespace func {
class CallOp;
}
namespace linalg {
class GenericOp;
struct ContractionDimensions;
} // namespace linalg

namespace xsmm {
class UnaryKindAttr;
@@ -122,6 +126,19 @@ func::CallOp buildXsmmCall(RewriterBase &rewriter, XsmmCallType callType,
SmallVector<XsmmOperand> operands, TypeRange results,
FlatSymbolRefAttr fnName, Operation *parentOp,
Operation *insertBefore);

std::optional<unsigned>
getPosInCodomain(unsigned dim, linalg::GenericOp linalgOp, AffineMap map);

LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
linalg::GenericOp linalgOp);

FailureOr<linalg::ContractionDimensions>
inferContractionDims(linalg::GenericOp genericOp);

std::optional<unsigned> getAffineBinaryOpExprIndex(AffineMap map, int index,
MLIRContext *context);

} // namespace utils
} // namespace xsmm
} // namespace mlir
148 changes: 146 additions & 2 deletions lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
@@ -8,21 +8,27 @@

#include "TPP/Dialect/Xsmm/XsmmUtils.h"
#include "TPP/Dialect/Xsmm/XsmmOps.h"
#include "TPP/IR/StructuredOpMatcher.h"
#include "TPP/Transforms/Utils/BuilderUtils.h"
#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "TPP/Transforms/Utils/ValueUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/TypeUtilities.h"

#include "TPP/Transforms/Utils/BuilderUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "xsmm-utils"

using namespace mlir;
using namespace mlir::linalg;

namespace mlir {
namespace xsmm {
namespace utils {
@@ -564,6 +570,144 @@ FailureOr<SmallVector<Attribute>> getBrgemmFlags(PatternRewriter &rewriter,
return attributes;
}

static llvm::SmallDenseSet<int64_t>
findIndexingOperand(AffineMap indexingMap,
ArrayRef<mlir::utils::IteratorType> iterators,
mlir::utils::IteratorType iter) {
assert(iterators.size() == indexingMap.getNumDims());
llvm::SmallDenseSet<int64_t> res;
for (AffineExpr e : indexingMap.getResults()) {
int position = -1;
if (isa<AffineDimExpr>(e)) {
auto expr = dyn_cast<AffineDimExpr>(e);
position = expr.getPosition();
} else if (isa<AffineBinaryOpExpr>(e)) {
auto lhs = dyn_cast<AffineBinaryOpExpr>(e).getLHS();
assert(isa<AffineDimExpr>(lhs));
position = (dyn_cast<AffineDimExpr>(lhs)).getPosition();
}
assert(position >= 0);
if (iterators[position] == iter &&
llvm::count_if(indexingMap.getResults(), [position](AffineExpr e) {
return e.isFunctionOfDim(position);
}) == 1)
res.insert(position);
}
return res;
}
namespace {
auto par = mlir::utils::IteratorType::parallel;
auto red = mlir::utils::IteratorType::reduction;
} // namespace

FailureOr<linalg::ContractionDimensions>
inferContractionDims(linalg::GenericOp linalgOp) {
Contributor

Just realized, this is a copy of the upstream code. This really should not happen.

This fix needs to be upstreamed. Let's stop work on this PR for now, please.

Contributor Author

I had to copy the upstream code because it does not pick up the floordiv dimension when the result is an AffineBinaryOpExpr. I forgot to discuss this with you.

auto indexingMaps = linalgOp.getIndexingMapsArray();
auto iterators = linalgOp.getIteratorTypesArray();
llvm::SmallDenseSet<int64_t> a =
findIndexingOperand(indexingMaps[0], iterators, par);
llvm::SmallDenseSet<int64_t> b =
findIndexingOperand(indexingMaps[1], iterators, par);
llvm::SmallDenseSet<int64_t> c =
findIndexingOperand(indexingMaps[2], iterators, par);

// A & C - B are the iterators involved in an outer-product along A (the LHS).
llvm::SmallDenseSet<int64_t> ac = a;
llvm::set_intersect(ac, c);
llvm::set_subtract(ac, b);
// B & C - A are the iterators involved in an outer-product along B (the RHS).
llvm::SmallDenseSet<int64_t> bc = b;
llvm::set_intersect(bc, c);
llvm::set_subtract(bc, a);
// A & B & C are the "batch" dimensions.
llvm::SmallDenseSet<int64_t> batches = a;
llvm::set_intersect(batches, b);
llvm::set_intersect(batches, c);

// A & B red are the reduction dimensions.
llvm::SmallDenseSet<int64_t> ra =
findIndexingOperand(indexingMaps[0], iterators, red);
llvm::SmallDenseSet<int64_t> rb =
findIndexingOperand(indexingMaps[1], iterators, red);
llvm::set_intersect(ra, rb);

// Return each set in sorted order.
ContractionDimensions dimensions{
SmallVector<unsigned, 2>(batches.begin(), batches.end()),
SmallVector<unsigned, 2>(ac.begin(), ac.end()),
SmallVector<unsigned, 2>(bc.begin(), bc.end()),
SmallVector<unsigned, 2>(ra.begin(), ra.end())};
llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
llvm::sort(dimensions.m.begin(), dimensions.m.end());
llvm::sort(dimensions.n.begin(), dimensions.n.end());
llvm::sort(dimensions.k.begin(), dimensions.k.end());
return dimensions;
}
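
A worked trace (not stated in the PR; derived from the code above and the vnni_brgemm_strided test below) of what this local inference produces for the VNNI BRGEMM maps:

  A: (d0, d1, d2, d3, d4) -> (d0, d1, d3)
  B: (d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)
  C: (d0, d1, d2, d3, d4) -> (d1, d2)
  iterators: [reduction, parallel, parallel, reduction, reduction]

findIndexingOperand maps the d3 floordiv 2 result back to its LHS dim d3, so the inference yields batch = {}, m = {d1}, n = {d2}, k = {d0, d3}. The VNNI dim d4 appears only on the B map, so it is never reported in k; that is why checkVNNIGemmStructure below separately requires the innermost iterator to be a reduction.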

std::optional<unsigned> getAffineBinaryOpExprIndex(AffineMap map, int index,
MLIRContext *context) {
for (unsigned i = 0; i < map.getNumResults(); i++) {
auto result = map.getResult(i);
if (isa<AffineBinaryOpExpr>(result) &&
dyn_cast<AffineBinaryOpExpr>(result).getLHS() ==
getAffineDimExpr(index, context)) {
return i;
}
}
llvm_unreachable("invalid binary op index");
}
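
A minimal standalone usage sketch for this helper. Assumptions not taken from the PR: the TPP and MLIR headers are on the include path, and the expected result position (1) is only illustrative.

#include "TPP/Dialect/Xsmm/XsmmUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  // Build the VNNI-style RHS map (d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4).
  mlir::AffineExpr d0 = mlir::getAffineDimExpr(0, &ctx);
  mlir::AffineExpr d2 = mlir::getAffineDimExpr(2, &ctx);
  mlir::AffineExpr d3 = mlir::getAffineDimExpr(3, &ctx);
  mlir::AffineExpr d4 = mlir::getAffineDimExpr(4, &ctx);
  auto map = mlir::AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0,
                                  {d0, d3.floorDiv(2), d2, d4}, &ctx);
  // The floordiv result whose LHS is d3 sits at position 1.
  if (auto pos = mlir::xsmm::utils::getAffineBinaryOpExprIndex(map, /*index=*/3, &ctx))
    llvm::outs() << "floordiv result position: " << *pos << "\n";
  return 0;
}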

LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
Copy link
Collaborator

Should it also validate that the innermost dimension is a valid VNNI factor?

linalg::GenericOp linalgOp) {
if (linalgOp->getNumOperands() != 3)
return failure();

if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) !=
xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16)) {
return failure();
}
auto iteratorTypes = linalgOp.getIteratorTypesArray();
if (iteratorTypes.size() < 4)
return failure();

auto contractionDims = inferContractionDims(linalgOp);
if (failed(contractionDims))
return failure();

unsigned m = contractionDims->m.back();
unsigned n = contractionDims->n.back();

// m and n dimensions must be parallel dimensions
if (!linalg::isParallelIterator(iteratorTypes[m]) ||
!linalg::isParallelIterator(iteratorTypes[n])) {
return failure();
}
Collaborator

nit: Isn't this guaranteed by the inference logic? That is, I think M and N are always parallel.

Contributor Author

Fair enough, will remove.


// innermost dimension must be a reduction dimension for VNNI type operations
if (!linalg::isReductionIterator(iteratorTypes[iteratorTypes.size() - 1])) {
Contributor

VNNI must have two reductions, no? Why are you not checking it the same way as m and n?

Contributor Author

I am. The contraction-dims inference API does not return the VNNI dimension as one of the k dims, which is why I check separately that the innermost dimension is a reduction and then check that k appears as a floordiv, which is what I do here.

Contributor

Perhaps the dim inference should check VNNI directly? This is confusing.

return failure();
}

// get the index of the iterator corresponding to the floordiv operation
auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
Collaborator

Shouldn't it just fail on size 0?

Contributor Author

On size 0 we still need to get the VNNI dimension, which we currently don't because the reduction dimension isn't present in both the M and N maps. Since the previous check guarantees an inner reduction, we can safely say there's a 0th dimension corresponding to k here.

Collaborator

Oh right, the VNNI dim won't show up at all.

Collaborator

Wait, I'm a bit lost now; wasn't this change supposed to catch the floordiv dim?

Contributor (@rengolin, Dec 11, 2024)

I don't think that's true. Just checking the last-dim reduction isn't enough, as this could just be the k reduction and a non-VNNI bf16 matmul.

Contributor Author

I'm afraid I didn't understand. I am checking for both the last-dim reduction and the k reduction, but if there's no k dimension it could still be a VNNI gemm, which is ensured by the previous check. It's never only a k reduction, because of the previous check.

Contributor Author

> Wait, I'm a bit lost now; wasn't this change supposed to catch the floordiv dim?

Sorry, I realize now the 0 assignment wasn't necessary. I have removed it.

auto map1 = linalgOp.getIndexingMapsArray()[1];
auto index = getAffineBinaryOpExprIndex(map1, k, linalgOp.getContext());
if (!index)
return failure();

// Ensure that the body of the generic operation is mul-add chain
// clang-format off
using namespace mlir::structured_match;
auto hasRightOpChain =
Contributor

Nit: I'd call it isMatmulChain, because "right" is subjective and non-descriptive.

StructuredOpMatcher::make<linalg::GenericOp>()
.region(MatchOne(0), WithOpChain<KindMul, KindAdd>(
/*captures=*/nullptr));
// clang-format on
if (!hasRightOpChain.match(linalgOp))
return failure();
return success();
}
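
On the review question above about validating the VNNI factor, a minimal sketch of such an extra check, assuming bf16 with a packing factor of 2; the helper name checkVNNIFactor and the hardcoded factor are illustrative only (VNNIUtils.h may already offer a blocking-factor query):

static LogicalResult checkVNNIFactor(linalg::GenericOp linalgOp) {
  // Operand 1 is the VNNI-packed matrix; its innermost static dimension
  // should equal the packing factor (2 for bf16).
  auto vnniType = dyn_cast<ShapedType>(linalgOp.getOperand(1).getType());
  if (!vnniType || vnniType.getRank() == 0 || !vnniType.hasStaticShape())
    return failure();
  constexpr int64_t bf16VnniFactor = 2;
  if (vnniType.getShape().back() != bf16VnniFactor)
    return failure();
  return success();
}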

template FailureOr<SmallVector<Attribute>>
getBrgemmFlags<xsmm::BrgemmDispatchOp>(PatternRewriter &rewriter,
xsmm::BrgemmDispatchOp dispatchOpTy,
22 changes: 14 additions & 8 deletions lib/TPP/Transforms/Vectorization.cpp
@@ -18,7 +18,6 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_VECTORIZATIONPASS
@@ -39,11 +38,8 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
PatternRewriter &rewriter) const override {
if (!linalgOp.hasPureBufferSemantics())
return failure();
if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) ==
xsmm::DataTypeAttr::get(rewriter.getContext(),
xsmm::DataType::BF16) &&
linalgOp.getIteratorTypes().size() >= 5 &&
linalgOp.getNumOperands() == 3) {
auto check = xsmm::utils::checkVNNIGemmStructure(rewriter, linalgOp);
if (succeeded(check)) {
SmallVector<int64_t> shape;
SmallVector<ReassociationIndices> indices;
int index = 0;
@@ -72,14 +68,24 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
}
auto map0 = linalgOp.getIndexingMapsArray()[0];
auto map1 = linalgOp.getIndexingMapsArray()[1];
map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1), 3);
int map1Index = map1.getNumResults() - 3;
// Set the innermost dimension of the first map to vnni dimension
map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1),
map0.getNumResults());
auto contractionDims = xsmm::utils::inferContractionDims(linalgOp);

auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
// Get the index of the iterator corresponding to the floordiv operation
auto map1Index = *xsmm::utils::getAffineBinaryOpExprIndex(
map1, k, linalgOp.getContext());

AffineExpr expr = map1.getResult(map1Index);
if (isa<AffineBinaryOpExpr>(expr)) {

auto expand = rewriter.create<memref::ExpandShapeOp>(
linalgOp.getLoc(), shape, linalgOp.getOperand(0), indices);
linalgOp.setOperand(0, expand.getResult());
// Replace the floordiv operation with just the LHS of the floordiv
// expression
map1 = map1.insertResult(
dyn_cast<AffineBinaryOpExpr>(map1.getResult(map1Index)).getLHS(),
map1Index + 1);
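A worked illustration of the map rewrite this pattern performs, derived from the vnni_brgemm_strided test below rather than from the PR text: the VNNI dim is appended to the LHS map, the floordiv result in the RHS map is replaced by its LHS dim, and the LHS operand is expanded to expose the VNNI dim.

  map0: (d0, d1, d2, d3, d4) -> (d0, d1, d3)                 becomes  (d0, d1, d3, d4)
  map1: (d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)  becomes  (d0, d3, d2, d4)
  %arg0: memref<8x8x8xbf16>  is expanded via memref.expand_shape to  memref<8x8x4x2xbf16>
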
72 changes: 72 additions & 0 deletions test/Passes/pass-vectorization.mlir
@@ -110,3 +110,75 @@ module {
// CHECK-NOT: %[[vec3:.*]] = vector.transfer_read
// CHECK-NOT: %[[vec4:.*]] = vector.contract
// CHECK-NOT: vector.transfer_write %[[vec4]]

// -----

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>

func.func @vnni_brgemm_strided(%arg0: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>,
%arg1: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>,
%arg2: memref<8x8xbf16>) {
linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]}
ins(%arg0, %arg1 : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>)
outs(%arg2 : memref<8x8xbf16>) {
^bb0(%in: bf16, %in_9: bf16, %out: bf16):
%11 = arith.mulf %in, %in_9 : bf16
%12 = arith.addf %out, %11 : bf16
linalg.yield %12 : bf16
}
return
}
// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
//
// CHECK-LABEL: func.func @vnni_brgemm_strided(
// CHECK: %[[arg0:.*]]: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, %[[arg1:.*]]: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>, %[[arg2:.*]]: memref<8x8xbf16>) {
// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : bf16
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[expand_shape:.*]] = memref.expand_shape %[[arg0]] {{\[}}[0], [1], [2, 3]] output_shape [8, 8, 4, 2] : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>> into memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>>
// CHECK: %[[read0:.*]] = vector.transfer_read %[[expand_shape]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true, true]}
// CHECK: %[[read1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %cst {in_bounds = [true, true, true, true]}
// CHECK: %[[read2:.*]] = vector.transfer_read %[[arg2]][%[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true]}
// CHECK: %[[read3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} %[[read0]], %[[read1]], %[[read2]] : vector<8x8x4x2xbf16>, vector<8x4x8x2xbf16> into vector<8x8xbf16>
// CHECK: vector.transfer_write %[[read3]], %[[arg2]][%[[c0]], %[[c0]]] {in_bounds = [true, true]}

// -----

#map = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d3 floordiv 2, d2, d0)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>

func.func @non_square_vnni_gemm(%arg0: memref<64x16xbf16>,
%arg1: memref<8x64x2xbf16>, %arg2: memref<64x64xbf16>) {
linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
ins(%arg0, %arg1 : memref<64x16xbf16>, memref<8x64x2xbf16>)
outs(%arg2 : memref<64x64xbf16>) {
^bb0(%in: bf16, %in_2: bf16, %out: bf16):
%1 = arith.mulf %in, %in_2 : bf16
%2 = arith.addf %out, %1 : bf16
linalg.yield %2 : bf16
}
return
}

// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
//
// CHECK-LABEL: func.func @non_square_vnni_gemm(
// CHECK: %[[arg0:.*]]: memref<64x16xbf16>, %[[arg1:.*]]: memref<8x64x2xbf16>, %[[arg2:.*]]: memref<64x64xbf16>) {
// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : bf16
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[expand_shape:.*]] = memref.expand_shape %[[arg0]] {{\[}}[0], [1, 2]] output_shape [64, 8, 2]
// CHECK %[[read0:.*]] = vector.transfer_read %[[expand_shape]][%[[c0]], %[[c0]], %[[c0]]], %cst {in_bounds = [true, true, true]}
// CHECK %[[read1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true]}
// CHECK: %[[read2:.*]] = vector.transfer_read %[[arg2]][%[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true]}
// CHECK: %[[read3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[read0]], %[[read1]], %[[read2]]
// CHECK: vector.transfer_write %3, %[[arg2]][%[[c0]], %[[c0]]] {in_bounds = [true, true]} : vector<64x64xbf16>, memref<64x64xbf16>