Skip to content

Commit

Permalink
clean up flash attention
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvtintin committed Dec 11, 2024
1 parent 8de3b9c commit 3c03930
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 52 deletions.
59 changes: 11 additions & 48 deletions axlearn/common/flash_attention/neuron_attention.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,12 @@
import os
from absl import logging
from functools import partial
import jax
import jax.numpy as jnp
from functools import partial
import jax.numpy as jnp


# enable buffer donation in neuron
jax._src.interpreters.mlir._platforms_with_donation.append("neuron")

if "LNC" not in os.environ:
raise ValueError("LNC environment variable is not set")

cores_per_lnc = os.environ["LNC"]
if cores_per_lnc == "2":
use_lnc = True
elif cores_per_lnc == "1":
use_lnc = False
else:
raise ValueError("LNC environment variable must be set to '1' or '2'")

if use_lnc:
from neuronxcc.nki._private_kernels.attention import (
flash_fwd_shardable,
flash_attn_bwd_shardable,
)
from neuronxcc.starfish.penguin.targets.nki.private_api import vnc

from jax import custom_vjp
import os

lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1

@partial(custom_vjp, nondiff_argnums=(4, 5))
def flash_attention(query, key, value, bias, causal, softmax_scale):
Expand All @@ -50,24 +28,14 @@ def _mha_forward(query, key, value, bias, causal, softmax_scale):
v = value.transpose(0, 2, 1, 3) # [batch_size, num_heads, kv_seq_len, d_model]

import neuronxcc.nki.language as nl
from neuronxcc.nki.kernels.attention import flash_attn_bwd, flash_fwd

# Create the output buffer
attn_output_shape = jax.ShapeDtypeStruct(
(batch_size, num_heads, q_seq_len, d_model), dtype=query.dtype
)
lse_shape = jax.ShapeDtypeStruct(
(batch_size, num_heads, nl.tile_size.pmax, q_seq_len // nl.tile_size.pmax),
dtype=jnp.float32,
)
from neuronxcc.nki.kernels.attention import flash_fwd
seed = jnp.array([1])

# Call the NKI kernel

assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f"unexpect num_heads: {num_heads}"

assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f"unexpected num_heads: {num_heads}"

if bias != None:
attn_output, lse = flash_fwd[batch_size, nl.nc(2) * (num_heads // 2)](
attn_output, lse = flash_fwd[batch_size, nl.nc(lnc) * (num_heads // lnc)](
q,
k,
v,
Expand All @@ -79,7 +47,7 @@ def _mha_forward(query, key, value, bias, causal, softmax_scale):
dropout_p=0.0,
)
else:
attn_output, lse = flash_fwd[batch_size, nl.nc(2) * (num_heads // 2)](
attn_output, lse = flash_fwd[batch_size, nl.nc(lnc) * (num_heads // lnc)](
q,
k,
v,
Expand All @@ -98,27 +66,22 @@ def _mha_forward(query, key, value, bias, causal, softmax_scale):
def _mha_backward(causal, softmax_scale, res, d_attn_output):
lse, o, q, k, v, bias = res
batch_size, num_heads, d_model, seq_len = q.shape
_, kv_seq_len, _, _ = k.shape

# Transpose the input tensors
o = o.transpose(0, 2, 3, 1)
dy = d_attn_output.transpose(0, 2, 3, 1)

# Transpose v tensor
v = jnp.transpose(v, axes=(0, 1, 3, 2))
# Create the output buffer shapes
d_query_shape = jax.ShapeDtypeStruct(q.shape, q.dtype)
d_key_shape = jax.ShapeDtypeStruct(k.shape, k.dtype)
d_value_shape = jax.ShapeDtypeStruct(v.shape, v.dtype)
seed = jnp.array([1])

from neuronxcc.nki.kernels.attention import flash_attn_bwd
import neuronxcc.nki.language as nl

# Call the NKI kernel
assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f"unexpected num_heads: {num_heads}"
if bias != None:
assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f"unexpected num_heads: {num_heads}"
d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(2) * (num_heads // 2)](
d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(lnc) * (num_heads // lnc)](
q,
k,
v,
Expand All @@ -133,7 +96,7 @@ def _mha_backward(causal, softmax_scale, res, d_attn_output):
softmax_scale=softmax_scale,
)
else:
d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(2) * (num_heads // 2)](
d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(lnc) * (num_heads // lnc)](
q,
k,
v,
Expand Down
8 changes: 4 additions & 4 deletions axlearn/common/flash_attention/neuron_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
@pytest.mark.parametrize(
"batch_size,seq_len,num_heads,per_head_dim",
[
(1, 2048, 1, 64),
# (1, 2048, 1, 64),
(2, 2048, 2, 64),
(1, 2048, 1, 128),
# (1, 2048, 1, 128),
(2, 2048, 2, 128),
(1, 2048, 8, 128),
(2, 2048, 8, 128),
Expand Down Expand Up @@ -71,9 +71,9 @@ def impl(q, k, v, bias):
@pytest.mark.parametrize(
"batch_size,num_heads,seq_len,per_head_dim",
[
(1, 1, 2048, 64),
# (1, 1, 2048, 64),
(2, 2, 2048, 64),
(1, 1, 2048, 128),
# (1, 1, 2048, 128),
(2, 2, 2048, 128),
(1, 8, 2048, 128),
(2, 8, 2048, 128),
Expand Down

0 comments on commit 3c03930

Please sign in to comment.