-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
VNNI gemm vectorization fix #994
Conversation
b9363dc
to
f667acf
Compare
lib/TPP/Transforms/Vectorization.cpp
Outdated
@@ -42,7 +42,7 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> { | |||
if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) == | |||
xsmm::DataTypeAttr::get(rewriter.getContext(), | |||
xsmm::DataType::BF16) && | |||
linalgOp.getIteratorTypes().size() >= 5 && | |||
linalgOp.getIteratorTypes().size() >= 4 && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which case this change covers? The test example has 5 iterators as I'd expect from VNNI BRGEMM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll update the test for a 4 iterators case, sorry I kept looking at the other part of the change that applies to the example but didn't look at the iterators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is the right check. The VNNI shape is +1 on the non-VNNI shape and not just larger than some random constant.
I'm also not sure what the numOperands
check is supposed to do, because all our operations have three operands.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, it was wrong before, this is fixed now, because it is supposed to be 3 iterators for gemm: Parallel, Parallel, Reduction, and vnni introduces one extra reduction dimension which makes it 4. On the other hand, in case of fp32, you need at least 4 dimensions in order to introduce an expand shape. Therefore, this is not a random check. As for numoperands, I agree that it is redundant, earlier, I had written that template for all operators which was why I had the check for numoperands to be three, I'll get rid of the numoperands check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this expands coverage to VNNI GEMM, alright. The logic is a bit iffy as we don't really validate that it is actually a VNNI format but I think in most mismatch cases the reassociations fails so nothing gets broken.
Overall, this wouldn't be needed as at all if we move away from this floordiv 2
indexing map.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not arguing it doesn't work, I'm arguing it's not fixing the bug.
It's just nudging it to work on the extra case you have and will need more nudges for future cases, ad infinitum. We need to fix the actual bug by either adding a reasonable comparison (relative dimensions between inputs) or using a more appropriate affine map as Adam suggests.
On the other hand, in case of fp32, you need at least 4 dimensions in order to introduce an expand shape.
Can you explain the reasoning here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed - but a proper fix is much larger work. If a nudge is sufficient to unblock current vectorization PoC, I'd go with that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not arguing it doesn't work, I'm arguing it's not fixing the bug.
It's just nudging it to work on the extra case you have and will need more nudges for future cases, ad infinitum. We need to fix the actual bug by either adding a reasonable comparison (relative dimensions between inputs) or using a more appropriate affine map as Adam suggests.
On the other hand, in case of fp32, you need at least 4 dimensions in order to introduce an expand shape.
Can you explain the reasoning here?
Sorry I misspoke, there is no fp32 case, its only bf16 and it needs atleast 4 dimensions to be a gemm. However, as @rengolin pointed out, its not sufficient to check that the number of iteratortypes is 4, we need to check that the iterators match that of a gemm/brgemm and the operations correspond to a gemm/brgemm. I'll update the patch. Thanks!
Can you provide a description of what the problem is, how you're trying to fix it and what were alternatives that you considered but didn't take? I don't know what this PR is trying to do. |
f667acf
to
bd47273
Compare
bd47273
to
c614f3f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like a fine temporary workaround
Overall, we just need to fix our affine maps for VNNI layouts but I'd do this refactoring after current pdll PRs
lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
Outdated
if (!linalg::isParallelIterator(iteratorTypes[m]) || | ||
!linalg::isParallelIterator(iteratorTypes[n])) { | ||
return failure(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Isn't this guaranteed by inference logic? That is, I think M and N are always parallel
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fair enough, will remove
lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
Outdated
return failure(); | ||
} | ||
|
||
auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't it just fail on size 0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On size 0, we still need to get vnni dimension, which we don't get currently as we don't have the reduction dimension present in both M and N maps. Since we know there's an inner reduction due to the previous check, we can safely say there's a 0th dimension corresponding to the k here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh right vnni dim won't show up at all.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait I'm a bit lost now, wasn't this change suppose to catch the floordiv dim?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that's true. Just checking the last dim reduction isn't enough, as this could just be the k
reduction and a non-VNNI bf16 matmul.
I am afraid I didn't understand, I am checking for both last dim reduction as well as k reduction, but if there's no k dimension, it could still be a VNNI gemm, which is ensured by the previous check. It's never a k reduction only, because of the previous check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait I'm a bit lost now, wasn't this change suppose to catch the floordiv dim?
Sorry I realize now the 0 assignment wasn't necessary. I have removed that.
llvm_unreachable("invalid binary op index"); | ||
} | ||
|
||
LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it also validate that the innermost dimension is a valid VNNI factor?
} | ||
|
||
// innermost dimension must be a reduction dimension for VNNI type operations | ||
if (!linalg::isReductionIterator(iteratorTypes[iteratorTypes.size() - 1])) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
VNNI must have two reductions, no? Why are you not checking in the same way as m
and n
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am, the thing is contraction dims infer api does not return vnni dimension as one of the ks and thats why I need to check the innermost dimension as a reduction seperately, followed by the k as a floordiv, which is what I do here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps the dim infer should check VNNI directly? This is confusing.
lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
Outdated
return failure(); | ||
} | ||
|
||
auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that's true. Just checking the last dim reduction isn't enough, as this could just be the k
reduction and a non-VNNI bf16 matmul.
I am afraid I didn't understand, I am checking for both last dim reduction as well as k reduction, but if there's no k dimension, it could still be a VNNI gemm, which is ensured by the previous check. It's never a k reduction only, because of the previous check.
lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
Outdated
// Ensure that the body of the generic operation is mul-add chain | ||
// clang-format off | ||
using namespace mlir::structured_match; | ||
auto hasRightOpChain = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I'd call it isMatmulChain
, because "right" is subjective and non descriptive.
} // namespace | ||
|
||
FailureOr<linalg::ContractionDimensions> | ||
inferContractionDims(linalg::GenericOp linalgOp) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just realized, this is a copy of the upstream code. This really should not happen.
This fix need to be upstream. Let's stop work in this PR for now, please.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to copy the upstream code because the upstream code does not get the floordiv dimension if its an affinebinaryopdimexpr. I forgot to discuss this with you.
As discussed offline, we'll refactor our affine maps to avoid temporary fixes. |
This PR fixes the VNNI contract lowering for gemms from linalg generics. There was a bug in the older code which did not consider the gemm iterators which would have size 4: 1 for vnni and 3 others corresponding to gemm. Also, the index setting of map was incorrect for VNNI, it was hardcoded to index 3, but should've been a function of the map's size, which I have changed now. I have left the strided VNNI test be, because it covers a positive VNNI example, although it would've worked with the older code as well. The test non_square_vnni_gemm covers both 4 iterator case, as well as setting the index of the map correctly.