Custom mask slows down attention. #724
Comments
Yes, the custom mask has significant overhead: the memory access pattern for the custom mask is not coalesced. For long sequences, it's encouraged to use the sparse API instead.
Is there a better way to support tree attention (speculative decoding) than with a custom mask?
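For reference, a tree-attention mask for speculative decoding can be derived from the draft tree's parent pointers: each draft token attends to itself and all of its ancestors. A minimal sketch; the `tree_attention_mask` helper is hypothetical (not a flashinfer API):

```python
import numpy as np

def tree_attention_mask(parents):
    """Boolean attention mask for a token tree (speculative decoding).

    parents[i] is the index of token i's parent, or -1 for the root.
    Token i may attend to itself and all of its ancestors.
    """
    n = len(parents)
    mask = np.zeros((n, n), dtype=bool)
    for i in range(n):
        j = i
        while j != -1:          # walk up to the root, marking ancestors
            mask[i, j] = True
            j = parents[j]
    return mask

# A small tree: root 0 with branches 0 -> 1 -> 2 and 0 -> 3
mask = tree_attention_mask([-1, 0, 1, 0])
```

Token 2 attends to {0, 1, 2}, while token 3 (a sibling branch) attends only to {0, 3}, so the two branches cannot see each other.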
@ZhongYingMatrix As I mentioned, using the sparse attention API would be much faster, especially for long contexts. See this unit test for their equivalence.
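For context, sparse attention APIs of this kind typically take a block-sparse (BSR-style) description of the mask rather than a dense array. Below is a minimal sketch of converting a dense boolean mask into `indptr`/`indices` arrays over fixed-size blocks; the `dense_mask_to_bsr` helper is hypothetical, and the exact layout expected by the library may differ:

```python
import numpy as np

def dense_mask_to_bsr(mask, R, C):
    """Convert a dense boolean mask to a BSR-style (indptr, indices) pair.

    A (R, C) block is kept if any entry in it is True. indptr[i]:indptr[i+1]
    gives the slice of `indices` holding kept block-column ids for block row i.
    """
    M, N = mask.shape
    assert M % R == 0 and N % C == 0
    # (num_block_rows, num_block_cols): True where a block has any active entry
    blocks = mask.reshape(M // R, R, N // C, C).any(axis=(1, 3))
    indptr = np.zeros(M // R + 1, dtype=np.int64)
    indices = []
    for bi in range(M // R):
        cols = np.nonzero(blocks[bi])[0]
        indices.extend(cols.tolist())
        indptr[bi + 1] = indptr[bi] + len(cols)
    return indptr, np.asarray(indices, dtype=np.int64)

# Causal mask of size 4 with 2x2 blocks: block row 0 keeps block {0},
# block row 1 keeps blocks {0, 1}
mask = np.tril(np.ones((4, 4), dtype=bool))
indptr, indices = dense_mask_to_bsr(mask, 2, 2)
```

Note that any partially-active block (like the diagonal blocks here) is kept whole, so the kernel still needs per-element masking inside kept blocks.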
Hey @yzh119, thanks for sharing the context! I have a similar issue in a different setting. My mask is a dense, block-wise mask (dense over most of the matrix). Which API would you suggest? I tried both, but since my case has a quite dense left-side block, both APIs seem to have high latency overhead. The sparse one can be slightly faster when the mask is made sparser, but in general block-wise sparsity seems hard to exploit here for a speedup. Another option is calling the API multiple times, but that may also introduce high latency. (I'm integrating this into sglang currently.) Thank you!

A more general question: at what sparsity ratio or pattern would you suggest using the sparse API rather than the dense one? Also, I'm curious about the timeline for FP8 support in FA3 in 0.2.1, and whether sliding window and custom masks are supported jointly in FA3. Thanks in advance!
I noticed that in a previous version you converted the float-type mask into a bit-packed array. How much time does this approach save? I measured the execution time of the bit-packed-mask kernel against the causal kernel and found that it runs about twice as slow. This still seems like a significant overhead.
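For scale: bit-packing a boolean mask shrinks it 8x versus a byte mask (32x versus float32), which cuts memory traffic but does not by itself make the accesses coalesced, so some kernel-side overhead can remain. A minimal NumPy sketch of the packing round trip (illustrative only, not the library's internal routine):

```python
import numpy as np

# A boolean causal mask for seq_len = 64
n = 64
mask = np.tril(np.ones((n, n), dtype=bool))

# Pack to 1 bit per entry, then unpack and verify the round trip
packed = np.packbits(mask.reshape(-1))
unpacked = np.unpackbits(packed)[: n * n].reshape(n, n).astype(bool)

assert packed.nbytes == n * n // 8   # 512 bytes vs 4096 (uint8) or 16384 (float32)
assert (unpacked == mask).all()
```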