
att_mask usage in ScaledDotProduct #846

Closed

naouarmehdi opened this issue Sep 3, 2023 · 8 comments

@naouarmehdi

🐛 Bug

Using an attention mask in the forward function of ScaledDotProduct causes the non-optimized code path to be taken in _matmul_with_mask (in xformers.components.attention.core).

Command To Reproduce

Steps to reproduce the behavior:

import torch
from xformers.components.attention import ScaledDotProduct

q = torch.rand(2, 1, 2, 4)
k = torch.rand(2, 1, 2, 4)
v = torch.rand(2, 1, 2, 4)
# Boolean mask: True = attend, False = mask out
mask = torch.tensor([[True, False], [False, True]])
attn = ScaledDotProduct()
test = attn(q=q, k=k, v=v, att_mask=mask)

Expected behavior

The forward function of the ScaledDotProduct class should use the efficient masked matmul operation from xFormers (torch.ops.xformers.matmul_with_mask(a, b, mask)).

Environment

  • PyTorch Version (e.g., 1.0): 2.0.1
  • OS (e.g., Linux): Ubuntu 22.04.1 LTS (x86_64)
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): 0.0.21
  • Python version: 3.9.16
  • CUDA/cuDNN version: 11.7
  • GPU models and configuration: NVIDIA RTX A6000
  • Any other relevant information:

Additional context

After some investigations, I noticed the following:

  1. The forward function of ScaledDotProduct first converts the mask into an AttentionMask with a float dtype.
  2. Later on (in xformers.components.attention.core, line 84), a check determines which code path is used for the masked matmul: if _has_cpp_library and mask.dtype == torch.bool:
  3. This condition therefore always evaluates to False, because the AttentionMask conversion has already cast the mask to float.
  4. Since the condition fails, _matmul_with_mask falls back to the inefficient computation: a plain PyTorch matrix multiplication followed by an addition of the attention mask.

For now, I have been able to avoid this issue by calling the scaled_dot_product_attention function from xformers.components.attention.core directly. This skips the conversion of att_mask to a float tensor and enables the use of torch.ops.xformers.matmul_with_mask(a, b, mask), roughly as sketched below.
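
A minimal sketch of that workaround (exact argument names may differ between xFormers versions, and the optimized kernel is only hit when the C++ extension is built):

import torch
from xformers.components.attention.core import scaled_dot_product_attention

q = torch.rand(2, 1, 2, 4)
k = torch.rand(2, 1, 2, 4)
v = torch.rand(2, 1, 2, 4)
# Keep the mask boolean so that _matmul_with_mask can dispatch to
# torch.ops.xformers.matmul_with_mask instead of the dense fallback.
mask = torch.tensor([[True, False], [False, True]])

out = scaled_dot_product_attention(q, k, v, mask)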

@danthe3rd
Contributor

Hi @naouarmehdi
Unfortunately, this part of the xFormers library hasn't received much attention recently, and we plan to deprecate/remove it (see also #848).
Going forward, we recommend using memory_efficient_attention from xformers.ops, which should be faster, though it might not support all the use cases that ScaledDotProduct does.

@naouarmehdi
Author

Hi @danthe3rd,
Thank you for the feedback. I ended up switching to memory_efficient_attention after noticing its significantly better performance.

I have two questions about using memory_efficient_attention instead of ScaledDotProduct (I was following your tutorial on attention masks and sparse attention):

  • In order to use an attention mask with memory_efficient_attention, I have to convert it to its float form (0 for True, -inf for False) and pass it as attn_bias. Am I right?
  • My attention masks are very sparse; however, I barely notice a performance difference between global and sparse attention when using memory_efficient_attention. My understanding is that attn_bias is simply added to the attention scores before the softmax, so it does not reduce the computational cost of the attention itself. If that is right, do you plan to support sparse attention with memory_efficient_attention in the future?

@danthe3rd
Copy link
Contributor

I have to convert it to its float form (0 for True, -inf for False) and pass it as attn_bias. Am I right?

That is correct
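
For example, a minimal sketch (shapes are [batch, seqlen, heads, head_dim]; depending on the xFormers version, a dense tensor bias may need an explicit [batch, heads, seqlen, seqlen] shape and a contiguous, aligned layout):

import torch
from xformers.ops import memory_efficient_attention

B, M, H, K = 2, 128, 4, 64
q = torch.rand(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.rand(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.rand(B, M, H, K, device="cuda", dtype=torch.float16)

# Boolean mask: True = attend, False = mask out
keep = torch.rand(M, M, device="cuda") > 0.5

# Float additive bias: 0 where attention is allowed, -inf where it is masked
attn_bias = torch.zeros(M, M, device="cuda", dtype=q.dtype)
attn_bias = attn_bias.masked_fill(~keep, float("-inf"))
# Broadcast to [B, H, M, M]; some versions require an explicit, contiguous layout
attn_bias = attn_bias.expand(B, H, M, M).contiguous()

out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)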

My understanding is that attn_bias is simply added to the attention scores before the softmax, so it does not reduce the computational cost of the attention itself.

This is also correct

do you plan to support sparse attention with memory_efficient_attention in the future?

We support a few very specific sparsity patterns in limited cases. For instance, block-diagonal sparsity (which can be used to get full dense attention across sentences of different lengths stacked together). We also support causality (which removes half of the computation).
What type of sparsity are you interested in?

@TsingWei

TsingWei commented Sep 19, 2023

Hi @danthe3rd ,

For instance, block-diagonal sparsity (which can be used to get full dense attention across sentences of different lengths stacked together). We also support causality (which removes half of the computation). What type of sparsity are you interested in?

What is block-diagonal? Does the attention mask look like this:

1 1 0 0 0 0
1 1 0 0 0 0
0 0 1 1 0 0
0 0 1 1 0 0
0 0 0 0 1 1
0 0 0 0 1 1

or this:

1 1 0 0 0 0
1 1 0 0 0 0
1 1 1 1 0 0
1 1 1 1 0 0
1 1 1 1 1 1
1 1 1 1 1 1

(Suppose 0 cancels the attention connection.) Or are both types supported?

@danthe3rd
Contributor

It supports only the first type
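
For reference, a minimal sketch of that first pattern with memory_efficient_attention, assuming xformers.ops.fmha.BlockDiagonalMask is available in your version (the three length-2 "sentences" are concatenated into a single sequence with batch size 1):

import torch
from xformers.ops import fmha, memory_efficient_attention

heads, head_dim = 1, 64
# Three sequences of length 2 stacked along one axis -> the 6x6 block-diagonal mask above
attn_bias = fmha.BlockDiagonalMask.from_seqlens([2, 2, 2])

q = torch.rand(1, 6, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.rand(1, 6, heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.rand(1, 6, heads, head_dim, device="cuda", dtype=torch.float16)

out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
# Causal attention can similarly be expressed with fmha.LowerTriangularMask()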

@naouarmehdi
Author

Hi @danthe3rd, thank you for your feedback. Block-diagonal sparsity is indeed useful for some of my projects. However, the sparsity I am most interested in is local attention based on 2D distance (xformers.components.attention.attention_patterns.local_2d_pattern). Is this one already supported?
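
For reference, a rough sketch of the pattern I mean, assuming local_2d_pattern takes the grid size and a distance threshold and returns a boolean (H*W) x (H*W) mask; as noted above, it would currently have to be materialized as a dense additive bias, so it gives no speedup:

import torch
from xformers.components.attention.attention_patterns import local_2d_pattern

H, W = 8, 8
# Boolean mask over the H*W tokens of a 2D grid: True where two positions
# are within the given 2D distance of each other
mask = local_2d_pattern(H, W, 2)
print(mask.shape)           # expected: torch.Size([64, 64])
print(mask.float().mean())  # density of the pattern (fraction of allowed positions)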

@danthe3rd
Contributor

Unfortunately this is not supported at the moment

@naouarmehdi
Author

Alright, thank you for your informative feedback. I will go ahead and close this ticket as I no longer have any outstanding questions.

bertmaher pushed a commit to bertmaher/xformers that referenced this issue Dec 20, 2024