duplicate the kernel if we cannot shard on num_heads (#30)
Merge fix for heads % 2 != 0
aws-zhehongb authored Dec 11, 2024
1 parent: e9e3279 · commit: d1d7da8
Showing 1 changed file with 15 additions and 8 deletions:

axlearn/common/flash_attention/neuron_attention.py
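
For context before the diff: the code previously asserted that `num_heads` was even so the head dimension could always be sharded across logical neuron cores (LNC). This commit replaces the assert with a fallback grid that launches one kernel instance per head when the heads cannot be split evenly. Below is a minimal sketch of the new grid selection, assuming `lnc = 2` and an installed `neuronxcc` toolkit; `select_grid` is a hypothetical helper for illustration, while in the diff the same logic is inlined in `_mha_forward` and `_mha_backward`:

```python
# Illustrative sketch of the grid selection introduced by this commit.
# `select_grid` is a hypothetical helper; the diff inlines this logic.
import neuronxcc.nki.language as nl


def select_grid(batch_size: int, num_heads: int, lnc: int = 2):
    """Pick the SPMD launch grid for the NKI flash-attention kernels."""
    if (num_heads % 2) == 0 and (num_heads // 2 > 0):
        # Even head count: shard the heads across `lnc` logical neuron cores.
        return batch_size, nl.nc(lnc) * (num_heads // lnc)
    # Odd head count: no sharding; duplicate the kernel once per head.
    return batch_size, num_heads
```
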
@@ -31,11 +31,14 @@ def _mha_forward(query, key, value, bias, causal, softmax_scale):
     from neuronxcc.nki.kernels.attention import flash_fwd
     seed = jnp.array([1])
 
-    # Call the NKI kernel
-    assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f"unexpected num_heads: {num_heads}"
+    # Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads
+    if (num_heads % 2) == 0 and (num_heads // 2 > 0):
+        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
+    else:
+        grid = batch_size, num_heads
 
     if bias != None:
-        attn_output, lse = flash_fwd[batch_size, nl.nc(lnc) * (num_heads // lnc)](
+        attn_output, lse = flash_fwd[grid](
             q,
             k,
             v,
@@ -47,7 +50,7 @@ def _mha_forward(query, key, value, bias, causal, softmax_scale):
             dropout_p=0.0,
         )
     else:
-        attn_output, lse = flash_fwd[batch_size, nl.nc(lnc) * (num_heads // lnc)](
+        attn_output, lse = flash_fwd[grid](
             q,
             k,
             v,
@@ -78,10 +81,14 @@ def _mha_backward(causal, softmax_scale, res, d_attn_output):
     from neuronxcc.nki.kernels.attention import flash_attn_bwd
     import neuronxcc.nki.language as nl
 
-    # Call the NKI kernel
-    assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f"unexpected num_heads: {num_heads}"
+    # Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads
+    if (num_heads % 2) == 0 and (num_heads // 2 > 0):
+        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
+    else:
+        grid = batch_size, num_heads
 
     if bias != None:
-        d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(lnc) * (num_heads // lnc)](
+        d_query, d_key, d_value = flash_attn_bwd[grid](
             q,
             k,
             v,
@@ -96,7 +103,7 @@ def _mha_backward(causal, softmax_scale, res, d_attn_output):
             softmax_scale=softmax_scale,
         )
     else:
-        d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(lnc) * (num_heads // lnc)](
+        d_query, d_key, d_value = flash_attn_bwd[grid](
             q,
             k,
             v,
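
The same grid is computed identically in `_mha_forward` and `_mha_backward`, so `flash_fwd` and `flash_attn_bwd` always launch with matching shapes. A quick illustration of the two launch shapes using the hypothetical `select_grid` sketch above (batch and head counts are arbitrary examples):

```python
grid_even = select_grid(batch_size=4, num_heads=8)
# -> (4, nl.nc(2) * 4): heads sharded across 2 logical neuron cores.

grid_odd = select_grid(batch_size=4, num_heads=7)
# -> (4, 7): kernel duplicated once per head, no core sharding.
```

The trade-off: duplicating the kernel per head gives up LNC parallelism on the head dimension, but it lets odd head counts run instead of failing the old assert.
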
