att_mask usage in ScaledDotProduct #846
Comments
Hi @naouarmehdi,

> Hi @danthe3rd, I have two questions regarding the usage of memory_efficient_attention instead of ScaledDotProduct: I was following your tutorial for the usage of Attention Masks and Sparse Attention:

That is correct.

This is also correct.

We have some limited use cases for some very specific sparsity. For instance, block-diagonal sparsity (which can be used to have full dense attention across sentences of different lengths stacked together). We also support causality (which removes half of the calculation).
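For illustration, here is a minimal sketch of what such block-diagonal and causal boolean masks look like, built with plain PyTorch (this only visualizes the patterns; it is not the xformers API for constructing them):

```python
import torch

# Three "sentences" of lengths 2, 3 and 4 stacked into one sequence of length 9.
# Each token may only attend to tokens from its own sentence, so the boolean
# attention mask is block-diagonal (True = attention allowed).
seq_lens = [2, 3, 4]
block_diag_mask = torch.block_diag(*[torch.ones(n, n) for n in seq_lens]).bool()  # (9, 9)

# Causality keeps only the lower-triangular half of the mask, which removes
# roughly half of the computation.
causal_mask = torch.tril(torch.ones(9, 9)).bool()

print(block_diag_mask.int())
print(causal_mask.int())
```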
Hi @danthe3rd,
What is block-diagonal? Does the attention mask look like ...
It supports only the first type.
Hi @danthe3rd, thank you for your feedback. Block-diagonal sparsity is in fact useful for some of my projects. However, the sparsity I was most interested in is the local attention based on the 2D distance (xformers.components.attention.attention_patterns.local_2d_pattern). Is this one already supported?
Unfortunately, this is not supported at the moment.
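For reference, a sketch of how the mask the question refers to is built; the exact signature of local_2d_pattern is assumed here and may differ between xformers versions:

```python
import torch
from xformers.components.attention.attention_patterns import local_2d_pattern

# Boolean mask over a flattened H x W grid of tokens: each token attends only
# to tokens within a given 2D distance of itself.
# (Parameter names are assumed; check the installed xformers version.)
H, W = 8, 8
mask = local_2d_pattern(H, W, distance=2.0)

print(mask.shape)   # expected: torch.Size([H * W, H * W])
print(mask.dtype)   # expected: torch.bool
```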
Alright, thank you for your informative feedback. I will go ahead and close this ticket as I no longer have any outstanding questions.
🐛 Bug
Using an attention mask in the forward function of ScaledDotProduct leads to the use of the non-optimized code path in _matmul_with_mask (in xformers.components.attention.core).
Command To Reproduce
Steps to reproduce the behavior:
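A hypothetical minimal snippet along the lines described above (module path, constructor arguments, and forward signature are assumed from the xformers components API, not taken from the original report):

```python
import torch
from xformers.components.attention import ScaledDotProduct

attention = ScaledDotProduct(dropout=0.0)

B, S, D = 1, 16, 32
q, k, v = (torch.randn(B, S, D) for _ in range(3))

# Boolean mask: True where attention is allowed.
att_mask = torch.tril(torch.ones(S, S, dtype=torch.bool))

# Passing att_mask this way ends up in the non-optimized branch of
# _matmul_with_mask, because the mask is converted to an additive float mask
# before the dtype check.
out = attention(q, k, v, att_mask=att_mask)
print(out.shape)
```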
Expected behavior
The forward function of the ScaledDotProduct class should use the efficient masked matmul operation from xFormers (torch.ops.xformers.matmul_with_mask(a, b, mask)).
Environment
How installed (conda, pip, source): conda

Additional context
After some investigation, I noticed the following check in _matmul_with_mask:
if _has_cpp_library and mask.dtype == torch.bool:
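For context, a simplified paraphrase of the surrounding dispatch logic (not the verbatim xformers source): the optimized kernel is only selected when the C++ extension is available and the mask is still boolean, so an additive float mask falls through to the dense fallback.

```python
import torch

# Simplified paraphrase of _matmul_with_mask; _has_cpp_library and
# torch.ops.xformers.matmul_with_mask come from the xformers library.
def _matmul_with_mask(a, b, mask):
    if mask is None:
        return a @ b
    if _has_cpp_library and mask.dtype == torch.bool:
        # Optimized path: fused kernel that applies the boolean mask directly.
        return torch.ops.xformers.matmul_with_mask(a, b, mask)
    # Fallback: dense matmul, then apply the mask (additive for float masks).
    att = a @ b
    if mask.dtype == torch.bool:
        att = att.masked_fill(~mask, float("-inf"))
    else:
        att = att + mask
    return att
```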
For now, I have been able to avoid this issue by using the scaled_dot_product_attention function from xformers.components.attention.core directly. This prevents the conversion of the attn_mask into a float tensor, enabling the usage of torch.ops.xformers.matmul_with_mask(a, b, mask).
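A sketch of that workaround; the exact signature of scaled_dot_product_attention is assumed here and should be checked against the installed xformers version:

```python
import torch
from xformers.components.attention.core import scaled_dot_product_attention

B, S, D = 1, 16, 32
q, k, v = (torch.randn(B, S, D) for _ in range(3))

# Keep the mask boolean so the dtype check above selects the optimized kernel.
att_mask = torch.tril(torch.ones(S, S, dtype=torch.bool))

# Signature assumed: scaled_dot_product_attention(q, k, v, att_mask, dropout=None)
out = scaled_dot_product_attention(q, k, v, att_mask)
print(out.shape)
```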