Flash attention #931

Open · wants to merge 9 commits into base: main

Conversation

@oliverdutton commented Apr 20, 2024

Flash attention is implemented with Pallas to reduce runtime and memory usage. It is added on an opt-in basis via the global config.

For a 759-residue protein and model_5, this drops peak memory consumption to 5 GB without minibatching and reduces runtime 2.3x on an A100 (15.2 s $\rightarrow$ 6.5 s, with minibatching of 256 for non-flash attention to avoid OOM).
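
For illustration, opting in could look roughly like the snippet below; the flag name `use_flash_attention` and its placement under `global_config` are assumptions made here for the example, not necessarily the exact key added by this PR.

import jax  # noqa: F401 (AlphaFold model config only; no kernel call here)
from alphafold.model import config

cfg = config.model_config('model_5')
# Assumed flag name/location for illustration; check the PR diff for the real key.
cfg.model.global_config.use_flash_attention = True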

Here's a Colab link showing the runtime improvement and, by visual inspection, no significant change in the prediction output.

When combined with #930 (bfloat16 support for monomer models), peak memory drops to only 2.7 GB and runtime to 5.6 seconds (a 2.7x speedup relative to non-flash, float32).

Notes:

Key variations from a reference flash attention kernel are:

  • Attention logit biasing is supported
  • Gating is supported
  • Some heads have only 8 channels; they're padded up to 16 within the kernel (a requirement of pl.dot). We still see a performance improvement relative to non-flash attention, and this keeps overall AlphaFold2 memory requirements linear.
  • Broadcast masks in the batch, q and head dimensions are supported (they're often size 1 and implicitly broadcast in AlphaFold2 einsums); a non-flash sketch of this attention variant is shown after the list.
There are guards against the kernel being called for sequence lengths shorter than the block sizes specified for q and k; in that case it exits to the reference kernel.
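
Schematically (assumed structure, not the PR's exact code), the guard amounts to something like the following, where `flash_attention` stands in for the Pallas kernel and `reference_attention` is the sketch above:

def attention_with_fallback(q, k, v, block_q=128, block_k=128, **kwargs):
    # q, k, v: [batch, seq, heads, channels]
    q_len, k_len = q.shape[-3], k.shape[-3]
    if q_len < block_q or k_len < block_k:
        # Too short for the requested block sizes: use the reference path.
        return reference_attention(q, k, v, **kwargs)
    return flash_attention(q, k, v, block_q=block_q, block_k=block_k, **kwargs)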

I haven't done correctness checks with multimer models; I would do so if there were a positive response to this pull request.
I'm not yet certain about the numerical stability of the implementation with bfloat16.

(I can switch out exp and log for exp2 and log2 for a small reduction in runtime; this leads to slightly different predictions, but I believe testing would show equivalent error in structure prediction.)
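
For reference, the exp2 substitution mentioned here relies on exp(x) = 2^(x * log2(e)), so the conversion factor can be folded into the logits before the exponential. A minimal sketch (not the kernel code):

import jax.numpy as jnp

LOG2_E = 1.4426950408889634  # log2(e)

def softmax_base_e(logits):
    m = logits.max(axis=-1, keepdims=True)
    p = jnp.exp(logits - m)
    return p / p.sum(axis=-1, keepdims=True)

def softmax_base_2(logits):
    # exp(x) == exp2(x * log2(e)); exp2 is typically cheaper on GPU,
    # at the cost of slightly different rounding.
    x = (logits - logits.max(axis=-1, keepdims=True)) * LOG2_E
    p = jnp.exp2(x)
    return p / p.sum(axis=-1, keepdims=True)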

@sokrypton

Hi @oliverdutton! Really cool contribution. Mind if we try adding it to ColabFold? We already have fused attention and bfloat16 integrated into the monomer model. It will be interesting to try flash attention as well.

TemplateEmbedding uses attention with a batch-dim broadcast mask, which wasn't supported:
`mask = template_mask[None, None, None,:]`
@oliverdutton (Author) commented Apr 21, 2024

@sokrypton Of course, I've made a pull request in ColabDesign with it (sokrypton/ColabDesign#173)

Removes any OOB indexing. Previously I allowed out-of-bounds loads and neutralised them with masks on qk, but I've seen NaNs appear that disappear with minorly varying MHLO, so this commit removes OOB indexing entirely.
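
To illustrate the two strategies being compared here (a standalone sketch, not the kernel code): previously, out-of-bounds key positions were loaded and then neutralised by masking the qk logits; the alternative is to pad the inputs up to a multiple of the block size so every block load stays in bounds.

import jax.numpy as jnp

def mask_oob_logits(logits, kv_len):
    # logits: [..., padded_kv_len]; entries at positions >= kv_len came from
    # out-of-bounds loads and are pushed to a large negative value before softmax.
    idx = jnp.arange(logits.shape[-1])
    return jnp.where(idx < kv_len, logits, -1e30)

def pad_to_block_multiple(x, block, axis):
    # Pad the sequence axis up to a multiple of the block size so no block
    # ever reads past the end of the array.
    pad = (-x.shape[axis]) % block
    widths = [(0, 0)] * x.ndim
    widths[axis] = (0, pad)
    return jnp.pad(x, widths)
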
@oliverdutton (Author) commented Apr 23, 2024

Pre-d4516d8 I find transient NaN behaviour on shapes which don't evenly divide the block size (i.e. where OOB loading occurs).

Gist to reproduce the problem:

import jax
from jax import numpy as jnp
from alphafold.model import modules

key = jax.random.PRNGKey(42)
nrepeats = 100

# jit once; the kernel's tuning options are static arguments.
f = jax.jit(modules.Attention.flash_kernel, static_argnames=(
    'return_residual', 'block_q', 'block_k', 'num_warps', 'num_stages',
    'grid', 'interpret', 'debug'))

for nres in range(128, 256):
    print(nres)
    for i in range(nrepeats):
        # [batch, seq_len, num_heads, channels]; the key is reused, so every
        # repeat runs on identical inputs (the NaNs are transient).
        q, k, v = jax.random.uniform(key, (3, 1024, nres, 8, 32))
        assert jnp.isfinite(f(q, k, v)).all(), f"Failed with {nres} on run {i}"

Post-d4516d8 the transient NaN behaviour disappears, so I hope this will now always be NaN-free.

@xlminfei commented May 4, 2024

Thank you very much; this improvement is very useful. I am using an RTX 3090 to predict a 3645-aa heterotetramer (a job that also exceeds the GPU memory limit). With this improvement, the prediction time of a single model has decreased from 59,000 seconds to 43,000 seconds.
