Using FlexAttention to compute attention with different masking patterns.
The speedup over `F.scaled_dot_product_attention`/xFormers and FlashAttention-2 tends to increase with sequence length. Timing plots are included for several sequence lengths, with the sequence length noted in each plot's title.
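Below is a minimal sketch of how FlexAttention is used with a custom masking pattern, here a causal `mask_mod` (shapes, dtype, and the compile step are illustrative, not the exact benchmark setup of this repo):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Illustrative shapes: batch, heads, sequence length, head dim.
B, H, S, D = 4, 8, 1024, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

# A mask_mod returns True for (q_idx, kv_idx) pairs that are allowed to attend.
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Pre-compute a block-sparse mask so fully-masked blocks are skipped entirely.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)

# Compiling flex_attention produces the fused, FlashAttention-style kernel.
flex_attention = torch.compile(flex_attention)
out = flex_attention(q, k, v, block_mask=block_mask)
```

Other masking patterns (sliding window, prefix-LM, document masking, etc.) only require swapping in a different `mask_mod`; the attention-gym repo collects many such examples.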
For each sequence length, the plot is accompanied by a table of execution times per masking pattern (columns: Mask | Execution Time).
(Reference - attention-gym repo)
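A rough sketch of how such a timing comparison can be made is shown below; the helper `time_fn`, the chosen sequence lengths, and the causal-only baseline are assumptions for illustration and are not the repo's actual benchmark script:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def time_fn(fn, *args, warmup=10, iters=50, **kwargs):
    # CUDA-event based timing; returns mean milliseconds per call.
    for _ in range(warmup):
        fn(*args, **kwargs)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

B, H, D = 4, 8, 64
compiled_flex = torch.compile(flex_attention)
for S in (1024, 4096, 16384):
    q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))
    block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)
    t_flex = time_fn(compiled_flex, q, k, v, block_mask=block_mask)
    t_sdpa = time_fn(F.scaled_dot_product_attention, q, k, v, is_causal=True)
    print(f"S={S}: flex={t_flex:.3f} ms  sdpa={t_sdpa:.3f} ms  speedup={t_sdpa / t_flex:.2f}x")
```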
- PyTorch Nightly (for FlexAttention, to be released with PyTorch 2.5)
- Refer to `requirements.txt` for the other requirements