VNNI gemm vectorization fix #994
Changes from 3 commits
@@ -8,21 +8,27 @@
#include "TPP/Dialect/Xsmm/XsmmUtils.h" | ||
#include "TPP/Dialect/Xsmm/XsmmOps.h" | ||
#include "TPP/IR/StructuredOpMatcher.h" | ||
#include "TPP/Transforms/Utils/BuilderUtils.h" | ||
#include "TPP/Transforms/Utils/VNNIUtils.h" | ||
#include "TPP/Transforms/Utils/ValueUtils.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/Linalg/Utils/Utils.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/Utils/IndexingUtils.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/IR/BuiltinTypeInterfaces.h" | ||
#include "mlir/IR/TypeUtilities.h" | ||
|
||
#include "TPP/Transforms/Utils/BuilderUtils.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
#include "llvm/ADT/SetOperations.h" | ||
#include "llvm/Support/Debug.h" | ||
#define DEBUG_TYPE "xsmm-utils" | ||
|
||
using namespace mlir; | ||
using namespace mlir::linalg; | ||
|
||
namespace mlir { | ||
namespace xsmm { | ||
namespace utils { | ||
|

@@ -564,6 +570,144 @@ FailureOr<SmallVector<Attribute>> getBrgemmFlags(PatternRewriter &rewriter,
  return attributes;
}

static llvm::SmallDenseSet<int64_t>
findIndexingOperand(AffineMap indexingMap,
                    ArrayRef<mlir::utils::IteratorType> iterators,
                    mlir::utils::IteratorType iter) {
  assert(iterators.size() == indexingMap.getNumDims());
  llvm::SmallDenseSet<int64_t> res;
  for (AffineExpr e : indexingMap.getResults()) {
    int position = -1;
    if (auto dimExpr = dyn_cast<AffineDimExpr>(e)) {
      position = dimExpr.getPosition();
    } else if (auto binExpr = dyn_cast<AffineBinaryOpExpr>(e)) {
      // VNNI-packed operands may carry results such as 'dK floordiv vnni';
      // take the dimension on the LHS of the binary expression.
      AffineExpr lhs = binExpr.getLHS();
      assert(isa<AffineDimExpr>(lhs));
      position = cast<AffineDimExpr>(lhs).getPosition();
    }
    assert(position >= 0);
    if (iterators[position] == iter &&
        llvm::count_if(indexingMap.getResults(), [position](AffineExpr e) {
          return e.isFunctionOfDim(position);
        }) == 1)
      res.insert(position);
  }
  return res;
}
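
// Illustration (assumed map and iterators, not taken from this patch): with
// iterators [parallel, parallel, reduction, reduction] and a VNNI-style map
//   (d0, d1, d2, d3) -> (d2 floordiv 2, d1, d3)
// the helper returns {1} for the parallel query and {2, 3} for the reduction
// query; the floordiv result is credited to d2 via its LHS.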

namespace {
auto par = mlir::utils::IteratorType::parallel;
auto red = mlir::utils::IteratorType::reduction;
} // namespace
|
||
FailureOr<linalg::ContractionDimensions> | ||
inferContractionDims(linalg::GenericOp linalgOp) { | ||
auto indexingMaps = linalgOp.getIndexingMapsArray(); | ||
auto iterators = linalgOp.getIteratorTypesArray(); | ||
llvm::SmallDenseSet<int64_t> a = | ||
findIndexingOperand(indexingMaps[0], iterators, par); | ||
llvm::SmallDenseSet<int64_t> b = | ||
findIndexingOperand(indexingMaps[1], iterators, par); | ||
llvm::SmallDenseSet<int64_t> c = | ||
findIndexingOperand(indexingMaps[2], iterators, par); | ||
|
||
// A & C - B are the iterators involved in an outer-product along A (the LHS). | ||
llvm::SmallDenseSet<int64_t> ac = a; | ||
llvm::set_intersect(ac, c); | ||
llvm::set_subtract(ac, b); | ||
// B & C - A are the iterators involved in an outer-product along B (the RHS). | ||
llvm::SmallDenseSet<int64_t> bc = b; | ||
llvm::set_intersect(bc, c); | ||
llvm::set_subtract(bc, a); | ||
// A & B & C are the "batch" dimensions. | ||
llvm::SmallDenseSet<int64_t> batches = a; | ||
llvm::set_intersect(batches, b); | ||
llvm::set_intersect(batches, c); | ||
|
||
// A & B red are the reduction dimensions. | ||
llvm::SmallDenseSet<int64_t> ra = | ||
findIndexingOperand(indexingMaps[0], iterators, red); | ||
llvm::SmallDenseSet<int64_t> rb = | ||
findIndexingOperand(indexingMaps[1], iterators, red); | ||
llvm::set_intersect(ra, rb); | ||
|
||
// Return each set in sorted order. | ||
ContractionDimensions dimensions{ | ||
SmallVector<unsigned, 2>(batches.begin(), batches.end()), | ||
SmallVector<unsigned, 2>(ac.begin(), ac.end()), | ||
SmallVector<unsigned, 2>(bc.begin(), bc.end()), | ||
SmallVector<unsigned, 2>(ra.begin(), ra.end())}; | ||
llvm::sort(dimensions.batch.begin(), dimensions.batch.end()); | ||
llvm::sort(dimensions.m.begin(), dimensions.m.end()); | ||
llvm::sort(dimensions.n.begin(), dimensions.n.end()); | ||
llvm::sort(dimensions.k.begin(), dimensions.k.end()); | ||
return dimensions; | ||
} | ||

std::optional<unsigned> getAffineBinaryOpExprIndex(AffineMap map, int index,
                                                   MLIRContext *context) {
  for (unsigned i = 0; i < map.getNumResults(); i++) {
    auto result = map.getResult(i);
    if (isa<AffineBinaryOpExpr>(result) &&
        dyn_cast<AffineBinaryOpExpr>(result).getLHS() ==
            getAffineDimExpr(index, context)) {
      return i;
    }
  }
  // No binary-op result with this dimension on its LHS; report it to the
  // caller instead of asserting, so the match can fail gracefully.
  return std::nullopt;
}
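
// Illustration (assumed map, not taken from this patch; ctx is an
// MLIRContext*): for the same VNNI-style B map as above, searching for the k
// dimension finds the floordiv result.
//   AffineExpr d2 = getAffineDimExpr(2, ctx);
//   AffineMap map = AffineMap::get(
//       /*dimCount=*/4, /*symbolCount=*/0,
//       {d2.floorDiv(2), getAffineDimExpr(1, ctx), getAffineDimExpr(3, ctx)},
//       ctx);
//   getAffineBinaryOpExprIndex(map, /*index=*/2, ctx); // returns 0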

Review comment: Should it also validate that the innermost dimension is a valid VNNI factor?

LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
                                     linalg::GenericOp linalgOp) {
  if (linalgOp->getNumOperands() != 3)
    return failure();

  if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) !=
      xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16)) {
    return failure();
  }
  auto iteratorTypes = linalgOp.getIteratorTypesArray();
  if (iteratorTypes.size() < 4)
    return failure();

  auto contractionDims = inferContractionDims(linalgOp);
  if (failed(contractionDims))
    return failure();

  unsigned m = contractionDims->m.back();
  unsigned n = contractionDims->n.back();

  // m and n dimensions must be parallel dimensions
  if (!linalg::isParallelIterator(iteratorTypes[m]) ||
      !linalg::isParallelIterator(iteratorTypes[n])) {
    return failure();
  }

Review discussion:
- nit: Isn't this guaranteed by the inference logic? That is, I think M and N are always parallel.
- Fair enough, will remove.

  // innermost dimension must be a reduction dimension for VNNI type operations
  if (!linalg::isReductionIterator(iteratorTypes[iteratorTypes.size() - 1])) {
    return failure();
  }

Review discussion:
- VNNI must have two reductions, no? Why are you not checking in the same way as
- I am. The thing is, the contraction dims infer API does not return the VNNI dimension as one of the ks, and that's why I need to check the innermost dimension as a reduction separately, followed by the k as a floordiv, which is what I do here.
- Perhaps the dim infer should check VNNI directly? This is confusing.

  // get the index of the iterator corresponding to the floordiv operation
  auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;

Review discussion:
- Shouldn't it just fail on size 0?
- On size 0, we still need to get the VNNI dimension, which we don't get currently, as we don't have the reduction dimension present in both M and N maps. Since we know there's an inner reduction due to the previous check, we can safely say there's a 0th dimension corresponding to the k here.
- Oh right, the VNNI dim won't show up at all.
- Wait, I'm a bit lost now. Wasn't this change supposed to catch the floordiv dim?
- I don't think that's true. Just checking the last dim reduction isn't enough, as this could just be the
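
  // As discussed in review: the inferred contraction dims never include the
  // VNNI iterator, and for a plain VNNI GEMM k may come back empty, so
  // dimension 0 is used as the k iterator. The k dimension is then expected
  // to appear in the B indexing map as the LHS of a floordiv expression,
  // which getAffineBinaryOpExprIndex locates below.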
  auto map1 = linalgOp.getIndexingMapsArray()[1];
  auto index = getAffineBinaryOpExprIndex(map1, k, linalgOp.getContext());
  if (!index)
    return failure();

Review comment: Nit: I'd call it

  // Ensure that the body of the generic operation is a mul-add chain
  // clang-format off
  using namespace mlir::structured_match;
  auto hasRightOpChain =
      StructuredOpMatcher::make<linalg::GenericOp>()
          .region(MatchOne(0), WithOpChain<KindMul, KindAdd>(
                                   /*captures=*/nullptr));
  // clang-format on
  if (!hasRightOpChain.match(linalgOp))
    return failure();
  return success();
}
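
One possible answer to the review question above about validating the VNNI factor, sketched for illustration only: checkVNNIBlockingFactor is a hypothetical helper, not part of this patch, and the operand index and the BF16 blocking of 2 are assumptions.

static LogicalResult checkVNNIBlockingFactor(linalg::GenericOp linalgOp) {
  // Assumed: the VNNI-packed operand is operand 1 (the one whose indexing map
  // carries the floordiv above).
  auto vnniType = dyn_cast<ShapedType>(linalgOp.getOperand(1).getType());
  if (!vnniType || !vnniType.hasStaticShape())
    return failure();
  // BF16 VNNI packing pairs two values along the innermost dimension.
  if (!vnniType.getElementType().isBF16() || vnniType.getShape().back() != 2)
    return failure();
  return success();
}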

template FailureOr<SmallVector<Attribute>>
getBrgemmFlags<xsmm::BrgemmDispatchOp>(PatternRewriter &rewriter,
                                       xsmm::BrgemmDispatchOp dispatchOpTy,

Review discussion:
- Just realized, this is a copy of the upstream code. This really should not happen. This fix needs to be upstream. Let's stop work on this PR for now, please.
- I had to copy the upstream code because the upstream code does not get the floordiv dimension if it is an AffineBinaryOpExpr. I forgot to discuss this with you.
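
For context, a sketch of the delta this copy carries relative to the upstream dimension-inference helper, paraphrasing the discussion above rather than quoting upstream code:

// The upstream helper only attributes plain AffineDimExpr results to an
// iterator, so a VNNI-style result such as 'd2 floordiv 2' is never credited
// to d2. The copy in this patch adds the extra branch in findIndexingOperand:
//   } else if (auto binExpr = dyn_cast<AffineBinaryOpExpr>(e)) {
//     // take the dimension on the LHS, e.g. d2 in 'd2 floordiv 2'
//     position = cast<AffineDimExpr>(binExpr.getLHS()).getPosition();
//   }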