Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VNNI gemm vectorization fix #994

Closed

Conversation

KavithaTipturMadhu
Copy link
Contributor

@KavithaTipturMadhu KavithaTipturMadhu commented Dec 5, 2024

This PR fixes the VNNI contract lowering for gemms from linalg generics. There was a bug in the older code which did not consider the gemm iterators which would have size 4: 1 for vnni and 3 others corresponding to gemm. Also, the index setting of map was incorrect for VNNI, it was hardcoded to index 3, but should've been a function of the map's size, which I have changed now. I have left the strided VNNI test be, because it covers a positive VNNI example, although it would've worked with the older code as well. The test non_square_vnni_gemm covers both 4 iterator case, as well as setting the index of the map correctly.

@@ -42,7 +42,7 @@ struct LinalgGenericToVector : OpRewritePattern<linalg::GenericOp> {
if (xsmm::utils::getDataType(rewriter, linalgOp.getOperand(0).getType()) ==
xsmm::DataTypeAttr::get(rewriter.getContext(),
xsmm::DataType::BF16) &&
linalgOp.getIteratorTypes().size() >= 5 &&
linalgOp.getIteratorTypes().size() >= 4 &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which case this change covers? The test example has 5 iterators as I'd expect from VNNI BRGEMM

Copy link
Contributor Author

@KavithaTipturMadhu KavithaTipturMadhu Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll update the test for a 4 iterators case, sorry I kept looking at the other part of the change that applies to the example but didn't look at the iterators.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is the right check. The VNNI shape is +1 on the non-VNNI shape and not just larger than some random constant.

I'm also not sure what the numOperands check is supposed to do, because all our operations have three operands.

Copy link
Contributor Author

@KavithaTipturMadhu KavithaTipturMadhu Dec 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it was wrong before, this is fixed now, because it is supposed to be 3 iterators for gemm: Parallel, Parallel, Reduction, and vnni introduces one extra reduction dimension which makes it 4. On the other hand, in case of fp32, you need at least 4 dimensions in order to introduce an expand shape. Therefore, this is not a random check. As for numoperands, I agree that it is redundant, earlier, I had written that template for all operators which was why I had the check for numoperands to be three, I'll get rid of the numoperands check.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this expands coverage to VNNI GEMM, alright. The logic is a bit iffy as we don't really validate that it is actually a VNNI format but I think in most mismatch cases the reassociations fails so nothing gets broken.

Overall, this wouldn't be needed as at all if we move away from this floordiv 2 indexing map.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not arguing it doesn't work, I'm arguing it's not fixing the bug.

It's just nudging it to work on the extra case you have and will need more nudges for future cases, ad infinitum. We need to fix the actual bug by either adding a reasonable comparison (relative dimensions between inputs) or using a more appropriate affine map as Adam suggests.

On the other hand, in case of fp32, you need at least 4 dimensions in order to introduce an expand shape.

Can you explain the reasoning here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed - but a proper fix is much larger work. If a nudge is sufficient to unblock current vectorization PoC, I'd go with that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not arguing it doesn't work, I'm arguing it's not fixing the bug.

It's just nudging it to work on the extra case you have and will need more nudges for future cases, ad infinitum. We need to fix the actual bug by either adding a reasonable comparison (relative dimensions between inputs) or using a more appropriate affine map as Adam suggests.

On the other hand, in case of fp32, you need at least 4 dimensions in order to introduce an expand shape.

Can you explain the reasoning here?

Sorry I misspoke, there is no fp32 case, its only bf16 and it needs atleast 4 dimensions to be a gemm. However, as @rengolin pointed out, its not sufficient to check that the number of iteratortypes is 4, we need to check that the iterators match that of a gemm/brgemm and the operations correspond to a gemm/brgemm. I'll update the patch. Thanks!

@rengolin
Copy link
Contributor

rengolin commented Dec 5, 2024

Can you provide a description of what the problem is, how you're trying to fix it and what were alternatives that you considered but didn't take? I don't know what this PR is trying to do.

@KavithaTipturMadhu KavithaTipturMadhu changed the title VNNI brgemm vectorization fix VNNI gemm vectorization fix Dec 6, 2024
Copy link
Collaborator

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a fine temporary workaround
Overall, we just need to fix our affine maps for VNNI layouts but I'd do this refactoring after current pdll PRs

Comment on lines 680 to 683
if (!linalg::isParallelIterator(iteratorTypes[m]) ||
!linalg::isParallelIterator(iteratorTypes[n])) {
return failure();
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Isn't this guaranteed by inference logic? That is, I think M and N are always parallel

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough, will remove

return failure();
}

auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it just fail on size 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On size 0, we still need to get vnni dimension, which we don't get currently as we don't have the reduction dimension present in both M and N maps. Since we know there's an inner reduction due to the previous check, we can safely say there's a 0th dimension corresponding to the k here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right vnni dim won't show up at all.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait I'm a bit lost now, wasn't this change suppose to catch the floordiv dim?

Copy link
Contributor

@rengolin rengolin Dec 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's true. Just checking the last dim reduction isn't enough, as this could just be the k reduction and a non-VNNI bf16 matmul.

I am afraid I didn't understand, I am checking for both last dim reduction as well as k reduction, but if there's no k dimension, it could still be a VNNI gemm, which is ensured by the previous check. It's never a k reduction only, because of the previous check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait I'm a bit lost now, wasn't this change suppose to catch the floordiv dim?

Sorry I realize now the 0 assignment wasn't necessary. I have removed that.

llvm_unreachable("invalid binary op index");
}

LogicalResult checkVNNIGemmStructure(PatternRewriter &rewriter,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it also validate that the innermost dimension is a valid VNNI factor?

}

// innermost dimension must be a reduction dimension for VNNI type operations
if (!linalg::isReductionIterator(iteratorTypes[iteratorTypes.size() - 1])) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VNNI must have two reductions, no? Why are you not checking in the same way as m and n?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am, the thing is contraction dims infer api does not return vnni dimension as one of the ks and thats why I need to check the innermost dimension as a reduction seperately, followed by the k as a floordiv, which is what I do here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps the dim infer should check VNNI directly? This is confusing.

return failure();
}

auto k = contractionDims->k.size() > 0 ? contractionDims->k.back() : 0;
Copy link
Contributor

@rengolin rengolin Dec 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's true. Just checking the last dim reduction isn't enough, as this could just be the k reduction and a non-VNNI bf16 matmul.

I am afraid I didn't understand, I am checking for both last dim reduction as well as k reduction, but if there's no k dimension, it could still be a VNNI gemm, which is ensured by the previous check. It's never a k reduction only, because of the previous check.

// Ensure that the body of the generic operation is mul-add chain
// clang-format off
using namespace mlir::structured_match;
auto hasRightOpChain =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I'd call it isMatmulChain, because "right" is subjective and non descriptive.

} // namespace

FailureOr<linalg::ContractionDimensions>
inferContractionDims(linalg::GenericOp linalgOp) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realized, this is a copy of the upstream code. This really should not happen.

This fix need to be upstream. Let's stop work in this PR for now, please.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to copy the upstream code because the upstream code does not get the floordiv dimension if its an affinebinaryopdimexpr. I forgot to discuss this with you.

@adam-smnk
Copy link
Collaborator

As discussed offline, we'll refactor our affine maps to avoid temporary fixes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants