VNNI gemm vectorization fix #994

Closed
17 changes: 17 additions & 0 deletions include/TPP/Dialect/Xsmm/XsmmUtils.h
@@ -26,6 +26,10 @@ class MemRefType;
namespace func {
class CallOp;
}
namespace linalg {
class GenericOp;
struct ContractionDimensions;
} // namespace linalg

namespace xsmm {
class UnaryKindAttr;
@@ -122,6 +126,19 @@ func::CallOp buildXsmmCall(RewriterBase &rewriter, XsmmCallType callType,
SmallVector<XsmmOperand> operands, TypeRange results,
FlatSymbolRefAttr fnName, Operation *parentOp,
Operation *insertBefore);

std::optional<unsigned>
getPosInCodomain(unsigned dim, linalg::GenericOp linalgOp, AffineMap map);

LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
linalg::GenericOp linalgOp);

FailureOr<linalg::ContractionDimensions>
inferContractionDims(linalg::GenericOp genericOp);

std::optional<unsigned> getAffineBinaryOpExprIndex(AffineMap map, int index,
MLIRContext *context);

} // namespace utils
} // namespace xsmm
} // namespace mlir
139 changes: 137 additions & 2 deletions lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
@@ -8,21 +8,27 @@

#include "TPP/Dialect/Xsmm/XsmmUtils.h"
#include "TPP/Dialect/Xsmm/XsmmOps.h"
#include "TPP/IR/StructuredOpMatcher.h"
#include "TPP/Transforms/Utils/BuilderUtils.h"
#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "TPP/Transforms/Utils/ValueUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/TypeUtilities.h"

#include "TPP/Transforms/Utils/BuilderUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "xsmm-utils"

using namespace mlir;
using namespace mlir::linalg;

namespace mlir {
namespace xsmm {
namespace utils {
@@ -564,6 +570,135 @@ FailureOr<SmallVector<Attribute>> getBrgemmFlags(PatternRewriter &rewriter,
return attributes;
}

static llvm::SmallDenseSet<int64_t>
findIndexingOperand(AffineMap indexingMap,
ArrayRef<mlir::utils::IteratorType> iterators,
mlir::utils::IteratorType iter) {
assert(iterators.size() == indexingMap.getNumDims());
llvm::SmallDenseSet<int64_t> res;
for (AffineExpr e : indexingMap.getResults()) {
int position = -1;
if (isa<AffineDimExpr>(e)) {
auto expr = dyn_cast<AffineDimExpr>(e);
position = expr.getPosition();
} else if (isa<AffineBinaryOpExpr>(e)) {
auto lhs = dyn_cast<AffineBinaryOpExpr>(e).getLHS();
assert(isa<AffineDimExpr>(lhs));
position = (dyn_cast<AffineDimExpr>(lhs)).getPosition();
}
assert(position >= 0);
if (iterators[position] == iter &&
llvm::count_if(indexingMap.getResults(), [position](AffineExpr e) {
return e.isFunctionOfDim(position);
}) == 1)
res.insert(position);
}
return res;
}
namespace {
auto par = mlir::utils::IteratorType::parallel;
auto red = mlir::utils::IteratorType::reduction;
} // namespace

FailureOr<linalg::ContractionDimensions>
inferContractionDims(linalg::GenericOp linalgOp) {
Contributor:

Just realized, this is a copy of the upstream code. This really should not happen.

This fix needs to go upstream. Let's stop work on this PR for now, please.

Contributor (Author):

I had to copy the upstream code because the upstream version does not pick up the floordiv dimension when the map result is an AffineBinaryOpExpr. I forgot to discuss this with you.

auto indexingMaps = linalgOp.getIndexingMapsArray();
auto iterators = linalgOp.getIteratorTypesArray();
llvm::SmallDenseSet<int64_t> a =
findIndexingOperand(indexingMaps[0], iterators, par);
llvm::SmallDenseSet<int64_t> b =
findIndexingOperand(indexingMaps[1], iterators, par);
llvm::SmallDenseSet<int64_t> c =
findIndexingOperand(indexingMaps[2], iterators, par);

// A & C - B are the iterators involved in an outer-product along A (the LHS).
llvm::SmallDenseSet<int64_t> ac = a;
llvm::set_intersect(ac, c);
llvm::set_subtract(ac, b);
// B & C - A are the iterators involved in an outer-product along B (the RHS).
llvm::SmallDenseSet<int64_t> bc = b;
llvm::set_intersect(bc, c);
llvm::set_subtract(bc, a);
// A & B & C are the "batch" dimensions.
llvm::SmallDenseSet<int64_t> batches = a;
llvm::set_intersect(batches, b);
llvm::set_intersect(batches, c);

// A & B red are the reduction dimensions.
llvm::SmallDenseSet<int64_t> ra =
findIndexingOperand(indexingMaps[0], iterators, red);
llvm::SmallDenseSet<int64_t> rb =
findIndexingOperand(indexingMaps[1], iterators, red);
llvm::set_intersect(ra, rb);

// Return each set in sorted order.
ContractionDimensions dimensions{
SmallVector<unsigned, 2>(batches.begin(), batches.end()),
SmallVector<unsigned, 2>(ac.begin(), ac.end()),
SmallVector<unsigned, 2>(bc.begin(), bc.end()),
SmallVector<unsigned, 2>(ra.begin(), ra.end())};
llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
llvm::sort(dimensions.m.begin(), dimensions.m.end());
llvm::sort(dimensions.n.begin(), dimensions.n.end());
llvm::sort(dimensions.k.begin(), dimensions.k.end());
return dimensions;
}
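
// [Editorial sketch, not part of this patch] The review thread above notes
// that this is a copy of upstream linalg::inferContractionDims, duplicated
// only because the upstream helper collects positions exclusively from plain
// AffineDimExpr results and therefore drops a VNNI-style `d3 floordiv 2`
// result from the candidate k dims. A minimal illustration of the difference,
// using a helper name of our own choosing:
static std::optional<unsigned> getContractionDimPosition(AffineExpr expr) {
  // Upstream behavior: only a bare dim contributes a position.
  if (auto dim = dyn_cast<AffineDimExpr>(expr))
    return dim.getPosition();
  // This PR's extension: peel `dK floordiv vnniFactor` down to dK so the
  // packed reduction dimension still participates in dim inference.
  if (auto bin = dyn_cast<AffineBinaryOpExpr>(expr))
    if (auto lhs = dyn_cast<AffineDimExpr>(bin.getLHS()))
      return lhs.getPosition();
  return std::nullopt;
}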

std::optional<unsigned> getAffineBinaryOpExprIndex(AffineMap map, int index,
MLIRContext *context) {
for (unsigned i = 0; i < map.getNumResults(); i++) {
auto result = map.getResult(i);
if (isa<AffineBinaryOpExpr>(result) &&
dyn_cast<AffineBinaryOpExpr>(result).getLHS() ==
getAffineDimExpr(index, context)) {
return i;
}
}
llvm_unreachable("invalid binary op index");
}

LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
Collaborator:

Should it also validate that the innermost dimension is a valid VNNI factor?

linalg::GenericOp linalgOp) {
if (linalgOp->getNumOperands() != 3)
return failure();

if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) !=
xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16)) {
return failure();
}
auto iteratorTypes = linalgOp.getIteratorTypesArray();
if (iteratorTypes.size() < 4)
return failure();

auto contractionDims = inferContractionDims(linalgOp);
if (failed(contractionDims))
return failure();

// innermost dimension must be a reduction dimension for VNNI type operations
if (!linalg::isReductionIterator(iteratorTypes[iteratorTypes.size() - 1])) {
Contributor:

VNNI must have two reductions, no? Why are you not checking it in the same way as m and n?

Contributor (Author):

I am; the thing is that the contraction-dims inference API does not return the VNNI dimension as one of the k dims, which is why I check the innermost dimension as a reduction separately, followed by k as a floordiv, which is what I do here.

Contributor:

Perhaps the dim inference should check VNNI directly? This is confusing.

return failure();
}

// get the index of the iterator corresponding to the floordiv operation
auto k = contractionDims->k.back();
auto map1 = linalgOp.getIndexingMapsArray()[1];
auto index = getAffineBinaryOpExprIndex(map1, k, linalgOp.getContext());
if (!index)
return failure();

// Ensure that the body of the generic operation is mul-add chain
// clang-format off
using namespace mlir::structured_match;
auto isMatmulChain =
StructuredOpMatcher::make<linalg::GenericOp>()
.region(MatchOne(0), WithOpChain<KindMul, KindAdd>(
/*captures=*/nullptr));
// clang-format on
if (!isMatmulChain.match(linalgOp))
return failure();
return success();
}
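
// [Editorial sketch, not part of this patch] Regarding the review question
// above about validating the VNNI blocking factor: one possible shape of
// such a check, assuming the VNNI-packed operand is operand 1 and
// hard-coding the bf16 blocking factor of 2 instead of querying VNNIUtils.
static LogicalResult checkVNNIBlockingFactor(linalg::GenericOp linalgOp) {
  auto packedType = dyn_cast<MemRefType>(linalgOp.getOperand(1).getType());
  if (!packedType || packedType.getRank() < 1 ||
      packedType.getNumDynamicDims() != 0)
    return failure();
  // The innermost dimension of the packed operand must equal the VNNI
  // blocking factor (2 for bf16).
  constexpr int64_t vnniFactor = 2;
  return success(packedType.getShape().back() == vnniFactor);
}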
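
// [Editorial sketch, not part of this patch] The "two reductions" discussion
// above in a nutshell: inferContractionDims only keeps reduction dims found
// in both the A and B maps, so the VNNI dim, which appears only in the packed
// B map (e.g. d4 in (d0, d3 floordiv 2, d2, d4)), never lands in `k`. The
// check above therefore splits into two parts, roughly:
static bool hasVNNIReductionPair(linalg::GenericOp linalgOp, unsigned k) {
  // Part 1: the innermost iterator is the VNNI reduction dim.
  auto iterators = linalgOp.getIteratorTypesArray();
  if (iterators.empty() || !linalg::isReductionIterator(iterators.back()))
    return false;
  // Part 2: the "real" k shows up as `k floordiv vnniFactor` in the B map.
  AffineMap bMap = linalgOp.getIndexingMapsArray()[1];
  for (AffineExpr e : bMap.getResults())
    if (auto bin = dyn_cast<AffineBinaryOpExpr>(e))
      if (bin.getKind() == AffineExprKind::FloorDiv &&
          bin.getLHS() == getAffineDimExpr(k, linalgOp.getContext()))
        return true;
  return false;
}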

template FailureOr<SmallVector<Attribute>>
getBrgemmFlags<xsmm::BrgemmDispatchOp>(PatternRewriter &rewriter,
xsmm::BrgemmDispatchOp dispatchOpTy,
22 changes: 14 additions & 8 deletions lib/TPP/Transforms/Vectorization.cpp
@@ -18,7 +18,6 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_VECTORIZATIONPASS
@@ -39,11 +38,8 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
PatternRewriter &rewriter) const override {
if (!linalgOp.hasPureBufferSemantics())
return failure();
if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) ==
xsmm::DataTypeAttr::get(rewriter.getContext(),
xsmm::DataType::BF16) &&
linalgOp.getIteratorTypes().size() >= 5 &&
linalgOp.getNumOperands() == 3) {
auto check = xsmm::utils::checkVNNIGemmStructure(rewriter, linalgOp);
if (succeeded(check)) {
SmallVector<int64_t> shape;
SmallVector<ReassociationIndices> indices;
int index = 0;
@@ -72,14 +68,24 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
}
auto map0 = linalgOp.getIndexingMapsArray()[0];
auto map1 = linalgOp.getIndexingMapsArray()[1];
map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1), 3);
int map1Index = map1.getNumResults() - 3;
// Set the innermost dimension of the first map to vnni dimension
map0 = map0.insertResult(map1.getResult(map1.getNumResults() - 1),
map0.getNumResults());
auto contractionDims = xsmm::utils::inferContractionDims(linalgOp);

auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
// Get the index of the iterator corresponding to the floordiv operation
auto map1Index = *xsmm::utils::getAffineBinaryOpExprIndex(
map1, k, linalgOp.getContext());

AffineExpr expr = map1.getResult(map1Index);
if (isa<AffineBinaryOpExpr>(expr)) {

auto expand = rewriter.create<memref::ExpandShapeOp>(
linalgOp.getLoc(), shape, linalgOp.getOperand(0), indices);
linalgOp.setOperand(0, expand.getResult());
// Replace the floordiv operation with just the LHS of the floordiv
// expression
map1 = map1.insertResult(
dyn_cast<AffineBinaryOpExpr>(map1.getResult(map1Index)).getLHS(),
map1Index + 1);
72 changes: 72 additions & 0 deletions test/Passes/pass-vectorization.mlir
@@ -110,3 +110,75 @@ module {
// CHECK-NOT: %[[vec3:.*]] = vector.transfer_read
// CHECK-NOT: %[[vec4:.*]] = vector.contract
// CHECK-NOT: vector.transfer_write %[[vec4]]

// -----

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>

func.func @vnni_brgemm_strided(%arg0: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>,
%arg1: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>,
%arg2: memref<8x8xbf16>) {
linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]}
ins(%arg0, %arg1 : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>)
outs(%arg2 : memref<8x8xbf16>) {
^bb0(%in: bf16, %in_9: bf16, %out: bf16):
%11 = arith.mulf %in, %in_9 : bf16
%12 = arith.addf %out, %11 : bf16
linalg.yield %12 : bf16
}
return
}
// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
//
// CHECK-LABEL: func.func @vnni_brgemm_strided(
// CHECK: %[[arg0:.*]]: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, %[[arg1:.*]]: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>, %[[arg2:.*]]: memref<8x8xbf16>) {
// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : bf16
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[expand_shape:.*]] = memref.expand_shape %[[arg0]] {{\[}}[0], [1], [2, 3]] output_shape [8, 8, 4, 2] : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>> into memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>>
// CHECK: %[[read0:.*]] = vector.transfer_read %[[expand_shape]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true, true]}
// CHECK: %[[read1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %cst {in_bounds = [true, true, true, true]}
// CHECK: %[[read2:.*]] = vector.transfer_read %[[arg2]][%[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true]}
// CHECK: %[[read3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} %[[read0]], %[[read1]], %[[read2]] : vector<8x8x4x2xbf16>, vector<8x4x8x2xbf16> into vector<8x8xbf16>
// CHECK: vector.transfer_write %[[read3]], %[[arg2]][%[[c0]], %[[c0]]] {in_bounds = [true, true]}

// -----

#map = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d3 floordiv 2, d2, d0)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>

func.func @non_square_vnni_gemm(%arg0: memref<64x16xbf16>,
%arg1: memref<8x64x2xbf16>, %arg2: memref<64x64xbf16>) {
linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
ins(%arg0, %arg1 : memref<64x16xbf16>, memref<8x64x2xbf16>)
outs(%arg2 : memref<64x64xbf16>) {
^bb0(%in: bf16, %in_2: bf16, %out: bf16):
%1 = arith.mulf %in, %in_2 : bf16
%2 = arith.addf %out, %1 : bf16
linalg.yield %2 : bf16
}
return
}

// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
//
// CHECK-LABEL: func.func @non_square_vnni_gemm(
// CHECK: %[[arg0:.*]]: memref<64x16xbf16>, %[[arg1:.*]]: memref<8x64x2xbf16>, %[[arg2:.*]]: memref<64x64xbf16>) {
// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : bf16
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[expand_shape:.*]] = memref.expand_shape %[[arg0]] {{\[}}[0], [1, 2]] output_shape [64, 8, 2]
// CHECK: %[[read0:.*]] = vector.transfer_read %[[expand_shape]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true]}
// CHECK: %[[read1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true, true]}
// CHECK: %[[read2:.*]] = vector.transfer_read %[[arg2]][%[[c0]], %[[c0]]], %[[cst]] {in_bounds = [true, true]}
// CHECK: %[[read3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[read0]], %[[read1]], %[[read2]]
// CHECK: vector.transfer_write %3, %[[arg2]][%[[c0]], %[[c0]]] {in_bounds = [true, true]} : vector<64x64xbf16>, memref<64x64xbf16>