Flash attention #931
base: main
Conversation
Hi @oliverdutton! Really cool contribution. Mind if we try adding it to ColabFold? We already have fused attention and bfloat16 integrated into the monomer model. It will be interesting to try flash attention as well.
TemplateEmbedding uses attention with a batch-dim broadcast, which wasn't supported: `mask = template_mask[None, None, None, :]`
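For illustration only (shapes here are hypothetical, not the exact ones in TemplateEmbedding), the broadcast in question relies on standard JAX/NumPy rules: a mask with singleton leading dims is applied across a batched logits tensor.

```python
import jax.numpy as jnp

# Hypothetical shapes for illustration: (batch, heads, queries, keys) logits
# and a 1-D mask over the key dimension.
batch, heads, n_q, n_k = 4, 8, 64, 64
logits = jnp.zeros((batch, heads, n_q, n_k))
template_mask = jnp.ones((n_k,))

mask = template_mask[None, None, None, :]    # shape (1, 1, 1, n_k)
masked = jnp.where(mask > 0, logits, -1e9)   # broadcasts over batch/heads/queries
```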
@sokrypton Of course, I've made a pull request in ColabDesign with it (sokrypton/ColabDesign#173)
Removes all OOB indexing. Previously I allowed out-of-bounds loads and fixed them with masks in qk. I've seen NaNs appear which disappear with minor variations in the MHLO. This commit removes all OOB indexing.
Pre d4516d8 I find transient NaN behaviour on shapes which don't evenly divide the block size (so OOB loading). Gist to reproduce the problem:

```python
import jax
from jax import numpy as jnp
from alphafold.model import model

key = jax.random.PRNGKey(42)
nrepeats = 100
for nres in range(128, 256):
    print(nres)
    for i in range(nrepeats):
        q, k, v = jax.random.uniform(key, (3, 1024, nres, 8, 32))
        f = jax.jit(model.modules.Attention.flash_kernel, static_argnames=(
            'return_residual', 'block_q', 'block_k', 'num_warps',
            'num_stages', 'grid', 'interpret', 'debug'))
        assert jnp.isfinite(f(q, k, v)).all(), f"Failed with {nres} on run {i}"
```

Post d4516d8 the transient NaN behaviour disappears, so I hope this will now always be NaN free.
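One general way to remove OOB loads entirely (a sketch of the idea only; not necessarily what d4516d8 does) is to pad the key/value sequence up to a multiple of the block size outside the kernel and bias away the padded positions before the softmax:

```python
import jax.numpy as jnp

def pad_to_block(x, block, axis):
    """Zero-pad `axis` of x up to the next multiple of `block`."""
    pad = (-x.shape[axis]) % block
    widths = [(0, 0)] * x.ndim
    widths[axis] = (0, pad)
    return jnp.pad(x, widths)

# Illustrative shapes: (batch, seq, heads, head_dim) with a key block of 128.
block_k = 128
k = jnp.ones((1, 200, 8, 32))
v = jnp.ones((1, 200, 8, 32))
n_k = k.shape[1]
k_pad = pad_to_block(k, block_k, axis=1)
v_pad = pad_to_block(v, block_k, axis=1)

# Padded key positions get a large negative bias so they carry ~zero softmax
# weight and never contribute uninitialised values to the output.
key_idx = jnp.arange(k_pad.shape[1])
logits_bias = jnp.where(key_idx < n_k, 0.0, -1e30)
```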
Thank you very much, this improvement is very useful. I am using an RTX 3090 to predict a 3645 aa heterotetramer. With this improvement, the prediction time for a single model has decreased from 59,000 seconds to 43,000 seconds (while still exceeding the GPU memory limit).
Flash attention implemented using Pallas to reduce runtime and memory usage. Added on an opt-in basis in the global config.
For a 759-residue protein and model_5, this drops peak memory consumption to 5 GB without minibatching and reduces runtime 2.3x on an A100 (15.2 $\rightarrow$ 6.5 seconds, with minibatching of 256 for non-flash attention to avoid OOM).
Here's a Colab link showing the runtime improvement and no significant change in prediction output on visual inspection.
When combined with #930 (bfloat16 support for monomer models), peak memory drops to only 2.7 GB and runtime to 5.6 seconds (a 2.7x speedup relative to non-flash, float32).
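For context, the memory saving comes from the standard flash-attention recurrence: keys and values are processed in blocks with a running (online) softmax, so the full [q_len, k_len] attention matrix is never materialised. A minimal pure-JAX sketch of that recurrence (illustrative only; the PR's Pallas kernel fuses this on-chip):

```python
import jax
import jax.numpy as jnp

def flash_attention_reference(q, k, v, block_k=128):
    """Blockwise attention with an online softmax (no full QK^T matrix).

    q: [q_len, d]; k, v: [k_len, d]. For simplicity, k_len must be a
    multiple of block_k here.
    """
    q_len, d = q.shape
    scale = 1.0 / jnp.sqrt(d)
    k_blocks = k.reshape(-1, block_k, d)
    v_blocks = v.reshape(-1, block_k, d)

    def step(carry, kv):
        m, l, acc = carry                        # running max, normaliser, output
        k_b, v_b = kv
        s = (q @ k_b.T) * scale                  # [q_len, block_k] logits for this block
        m_new = jnp.maximum(m, s.max(axis=-1))
        p = jnp.exp(s - m_new[:, None])          # unnormalised weights for this block
        correction = jnp.exp(m - m_new)          # rescale previous partial sums
        l_new = l * correction + p.sum(axis=-1)
        acc_new = acc * correction[:, None] + p @ v_b
        return (m_new, l_new, acc_new), None

    init = (jnp.full((q_len,), -jnp.inf),
            jnp.zeros((q_len,)),
            jnp.zeros((q_len, d)))
    (m, l, acc), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / l[:, None]
```

A quick sanity check is to compare this against `jax.nn.softmax(q @ k.T / jnp.sqrt(d)) @ v` on small shapes.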
Notes:
- Key variations from a reference flash attention kernel: there are guards against the kernel being called for sequence lengths shorter than the block sizes specified for q and k, which exit to the reference kernel.
- I haven't done correctness checks with multimer models; I would do so if there is a positive response to this pull request.
- I'm not yet certain of the numerical stability of the implementation with bfloat16.
- I can switch out exp and log for exp2 and log2 for a small reduction in runtime; this leads to slightly different predictions, but I believe testing would show equivalent error in structure prediction (a short sketch of the identity follows below).
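For the exp2/log2 variant mentioned in the last note, the change relies only on the identity exp(x) = 2^(x · log2 e), so the constant can be folded into the existing 1/sqrt(d) query scaling; a minimal sketch (illustrative, not the PR's code):

```python
import jax.numpy as jnp

LOG2_E = 1.4426950408889634  # log2(e)

x = jnp.linspace(-5.0, 5.0, 11)
# exp(x) == 2**(x * log2(e)), up to floating-point rounding.
assert jnp.allclose(jnp.exp(x), jnp.exp2(x * LOG2_E))

# In a kernel, multiplying q by LOG2_E / sqrt(d) instead of 1 / sqrt(d)
# lets exp2 replace exp with no extra per-element multiplies; outputs
# differ only at the level of rounding error.
```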