
Custom mask slows down attention. #724

Open
qiyuxinlin opened this issue Jan 8, 2025 · 4 comments

Comments

@qiyuxinlin

I noticed that in a previous version you converted the float-type mask into a bit-packed array for mask usage. How much time does this approach actually save? I timed the bit-packed mask against the causal kernel and found the masked version runs about twice as slow, which still seems like a significant overhead.
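The bit-packing mentioned above can be sketched with NumPy: a float mask (0 = masked, 1 = visible) is reduced to one bit per entry, cutting the mask's memory traffic by 32x versus float32. This is an illustrative sketch of the idea, not flashinfer's internal code.

```python
import numpy as np

# A float attention mask for a 4x6 score matrix (causal with 2 extra
# visible diagonals, as an example).
float_mask = np.tril(np.ones((4, 6), dtype=np.float32), k=2)

# Pack the boolean mask row-major into bits: 8 mask entries per byte.
bool_mask = float_mask.astype(bool).reshape(-1)
packed = np.packbits(bool_mask)  # ceil(24 / 8) = 3 bytes instead of 96

# Unpacking recovers the original mask exactly.
unpacked = np.unpackbits(packed, count=bool_mask.size).astype(bool)
assert (unpacked == bool_mask).all()
```

The savings are in memory footprint and bandwidth; as discussed below, the kernel still pays for non-coalesced accesses when reading the packed mask.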

@yzh119
Collaborator

yzh119 commented Jan 8, 2025

Yes, custom masks have significant overhead: the memory access pattern for the custom mask is not coalesced.

For long sequences, it's encouraged to use the sparse API instead:
https://docs.flashinfer.ai/api/sparse.html
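The sparse API describes the mask at block granularity rather than per element. A hedged sketch of how a dense boolean mask can be converted into the BSR-style (indptr, indices) description that a block-sparse attention API would consume, using SciPy; the exact parameter names flashinfer expects may differ, so treat this as input preparation only.

```python
import numpy as np
import scipy.sparse as sp

R, C = 2, 2                       # tile (block) size
mask = np.zeros((6, 6), dtype=bool)
mask[:2, :] = True                # first two query rows attend everywhere
mask[2:4, 2:4] = True             # a diagonal block
mask[4:6, 4:6] = True             # another diagonal block

# BSR keeps only the nonzero (R x C) tiles of the mask.
bsr = sp.bsr_matrix(mask.astype(np.float32), blocksize=(R, C))
indptr, indices = bsr.indptr, bsr.indices
# indptr[i]..indptr[i+1] spans the column-block indices of row-block i,
# so only 5 of the 9 tiles are ever touched by the kernel.
```

Because attention is skipped entirely for absent tiles, the cost scales with the number of nonzero blocks instead of the full sequence-length product.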

@ZhongYingMatrix

Is there a better way to support tree attention (speculative decoding) than with a custom mask?
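For context on the tree-attention case: in speculative decoding, a draft token may attend only to its ancestors in the draft tree (and itself). A minimal sketch of deriving that mask from parent pointers; `tree_attention_mask` and `parent` are illustrative names, not a flashinfer helper.

```python
import numpy as np

def tree_attention_mask(parent):
    """mask[j, i] is True iff i is an ancestor of j (or i == j).

    parent[i] gives the index of token i's parent; -1 marks the root.
    """
    n = len(parent)
    mask = np.zeros((n, n), dtype=bool)
    for j in range(n):
        i = j
        while i != -1:           # walk up to the root, marking ancestors
            mask[j, i] = True
            i = parent[i]
    return mask

# A small draft tree: 0 is the root; 1 and 2 branch off 0; 3 extends 1.
m = tree_attention_mask([-1, 0, 0, 1])
```

Such a mask is irregular but often block-structured once tokens are laid out depth-first, which is what makes the sparse API attractive here.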

@yzh119
Collaborator

yzh119 commented Jan 18, 2025

@ZhongYingMatrix as I mentioned, using the sparse attention API would be much faster, especially for long context.

See this unittest for their equivalence:
https://github.com/flashinfer-ai/flashinfer/blob/a0e99a3a820109763d9a757138a5cdf7bbcd1f85/tests/test_block_sparse.py

@qingquansong

qingquansong commented Jan 18, 2025

Hey @yzh119, thanks for sharing the context! I have a similar issue in a different setting. Suppose my mask is a dense block-wise mask with the structure below, which is mostly dense. Which API would you suggest? I tried both, but since my case has a quite dense left-hand block, both APIs seem to have high latency overhead. The sparse one can be slightly faster when the mask is made sparser, but in general block-wise sparsity seems hard to exploit here for a speedup. Another option is calling the API multiple times, but that may introduce high latency as well. (I'm integrating this into sglang currently.) Thank you!

A more general question: at what sparsity ratio or pattern would you suggest using the sparse API rather than the dense one?

Also, I'm curious about the timeline for FP8 support for FA3 in 0.2.1, and whether sliding window and custom mask are supported jointly in FA3. Thank you in advance!

[Image: block-structured mask]
(In the figure, the first part can be quite dense and quite long, e.g. 32k x 24k, and the later small triangular blocks can be numerous, e.g. 50 blocks of 100 x 100 each. I'm trying the 0.1.6 version, by the way.)
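The mask shape described above can be built explicitly: a dense rectangular region (every query sees the full shared prefix) followed by small causal lower-triangular blocks on the diagonal. Sizes here are toy values standing in for the 32k x 24k prefix and the fifty 100 x 100 blocks; `prefix_plus_tri_blocks` is an illustrative name, not an existing API.

```python
import numpy as np

def prefix_plus_tri_blocks(n_queries, prefix_len, block):
    """Dense shared-prefix region plus causal diagonal blocks."""
    n_keys = prefix_len + n_queries
    mask = np.zeros((n_queries, n_keys), dtype=bool)
    mask[:, :prefix_len] = True                  # dense prefix region
    for b in range(0, n_queries, block):         # causal blocks after it
        e = min(b + block, n_queries)
        mask[b:e, prefix_len + b : prefix_len + e] = np.tril(
            np.ones((e - b, e - b), dtype=bool)
        )
    return mask

m = prefix_plus_tri_blocks(n_queries=6, prefix_len=4, block=3)
```

Writing the mask this way makes the trade-off concrete: the prefix region is fully dense (no sparsity to exploit there), while the diagonal blocks are exactly the pattern a block-sparse API handles well, which matches the observation that neither API alone wins here.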
