clean flash attention, enable GQA since it works
apoorvtintin committed Dec 6, 2024
1 parent fd74624 commit eaf45a4
Showing 5 changed files with 147 additions and 102 deletions.
227 changes: 135 additions & 92 deletions axlearn/common/flash_attention/neuron_attention.py
@@ -10,111 +10,154 @@
from jax_neuronx import nki_call
from neuronxcc.nki._private_kernels.legacy.attention import flash_attn_bwd, flash_fwd

if 'LNC' not in os.environ:
# enable buffer donation in neuron
jax._src.interpreters.mlir._platforms_with_donation.append("neuron")

if "LNC" not in os.environ:
raise ValueError("LNC environment variable is not set")

cores_per_lnc = os.environ['LNC']
if cores_per_lnc == '2':
cores_per_lnc = os.environ["LNC"]
if cores_per_lnc == "2":
use_lnc = True
elif cores_per_lnc == '1':
elif cores_per_lnc == "1":
use_lnc = False
else:
raise ValueError("LNC environment variable must be set to '1' or '2'")

disable_sharded_attn_kernel = os.environ.get('DISABLE_SHARDED_ATTN_KERNEL')
if disable_sharded_attn_kernel is not None:
use_lnc = False

if use_lnc:
from neuronxcc.nki._private_kernels.attention import flash_fwd_shardable, flash_attn_bwd_shardable
from neuronxcc.nki._private_kernels.attention import (
flash_fwd_shardable,
flash_attn_bwd_shardable,
)
from neuronxcc.starfish.penguin.targets.nki.private_api import vnc

from jax import custom_vjp

@partial(custom_vjp, nondiff_argnums=(3,4))
def flash_attention(query, key, value, causal, softmax_scale):
out, _ = _mha_forward(query, key, value, causal, softmax_scale)
return out

def _mha_forward(query, key, value, causal, softmax_scale):
# Get the batch size, sequence lengths, number of heads, and hidden dimension
batch_size, q_seq_len, num_heads, d_model = query.shape
_, kv_seq_len, _, _ = key.shape

# Transpose the query, key, and value tensors
q = query.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, q_seq_len]
k = key.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, kv_seq_len]
v = value.transpose(0, 2, 1, 3) # [batch_size, num_heads, kv_seq_len, d_model]

# Create the output buffer
import neuronxcc.nki.language as nl
attn_output_shape = jax.ShapeDtypeStruct((batch_size, num_heads, q_seq_len, d_model), dtype=query.dtype)
lse_shape = jax.ShapeDtypeStruct((batch_size, num_heads, nl.tile_size.pmax, q_seq_len // nl.tile_size.pmax), dtype=jnp.float32)
seed = jnp.array([1])

# Call the NKI kernel
if os.environ.get('ENABLE_NEW_UNSHARDED_ATTN_KERNEL'):
from neuronxcc.nki.kernels.attention import flash_attn_bwd, flash_fwd
import neuronxcc.nki.language as nl

assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f'unexpect num_heads: {num_heads}'
attn_output, lse = flash_fwd[batch_size, nl.nc(2) * (num_heads//2)](q, k, v, seed, use_causal_mask=causal, softmax_scale=softmax_scale, mixed_precision=True, dropout_p=0.0)
else:
from neuronxcc.nki._private_kernels.legacy.attention import flash_fwd
from neuronxcc.nki._private_kernels.attention import flash_fwd_shardable
from neuronxcc.starfish.penguin.targets.nki.private_api import vnc
attn_output, lse = nki_call(
partial(flash_fwd_shardable if use_lnc else flash_fwd, use_causal_mask=causal, softmax_scale=softmax_scale, mixed_precision=True, dropout_p=0.0),
q, k, v, seed,
out_shape=(attn_output_shape, lse_shape),
grid=(batch_size, num_heads, vnc(2)) if use_lnc else (batch_size, num_heads)
)
# Transpose the output back to the original shape
attn_output = attn_output.transpose(0, 2, 1, 3) # [batch_size, q_seq_len, num_heads, d_model]

return attn_output, (lse, attn_output, q, k, v)

@partial(custom_vjp, nondiff_argnums=(4, 5))
def flash_attention(query, key, value, bias, causal, softmax_scale):
# NOTE : Merge with upstream. Old code supports both 2d and 4d bias but upstream code only supports 4d.
# We no longer need 2d logit_bias but should sync how we merge this check with upstream.
# assert bias.ndim == 4, f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}"
out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale)
return out


def _mha_forward(query, key, value, bias, causal, softmax_scale):
# Get the batch size, sequence lengths, number of heads, and hidden dimension
batch_size, q_seq_len, num_heads, d_model = query.shape
_, kv_seq_len, _, _ = key.shape

# Transpose the query, key, and value tensors
q = query.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, q_seq_len]
k = key.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, kv_seq_len]
v = value.transpose(0, 2, 1, 3) # [batch_size, num_heads, kv_seq_len, d_model]

import neuronxcc.nki.language as nl
from neuronxcc.nki.kernels.attention import flash_attn_bwd, flash_fwd

# Create the output buffer
attn_output_shape = jax.ShapeDtypeStruct(
(batch_size, num_heads, q_seq_len, d_model), dtype=query.dtype
)
lse_shape = jax.ShapeDtypeStruct(
(batch_size, num_heads, nl.tile_size.pmax, q_seq_len // nl.tile_size.pmax),
dtype=jnp.float32,
)
seed = jnp.array([1])

# Call the NKI kernel

assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f"unexpect num_heads: {num_heads}"

if bias != None:
attn_output, lse = flash_fwd[batch_size, nl.nc(2) * (num_heads // 2)](
q,
k,
v,
seed,
bias,
use_causal_mask=causal,
softmax_scale=softmax_scale,
mixed_precision=True,
dropout_p=0.0,
)
else:
attn_output, lse = flash_fwd[batch_size, nl.nc(2) * (num_heads // 2)](
q,
k,
v,
seed,
use_causal_mask=causal,
softmax_scale=softmax_scale,
mixed_precision=True,
dropout_p=0.0,
)
# Transpose the output back to the original shape
attn_output = attn_output.transpose(0, 2, 1, 3) # [batch_size, q_seq_len, num_heads, d_model]

return attn_output, (lse, attn_output, q, k, v, bias)


def _mha_backward(causal, softmax_scale, res, d_attn_output):
lse, o, q, k, v = res
batch_size, num_heads, d_model, seq_len = q.shape
_, kv_seq_len, _, _ = k.shape

# Transpose the input tensors
o = o.transpose(0, 2, 3, 1)
dy = d_attn_output.transpose(0, 2, 3, 1)

# Transpose v tensor
v = jnp.transpose(v, axes=(0, 1, 3, 2))
# Create the output buffer shapes
d_query_shape = jax.ShapeDtypeStruct(q.shape, q.dtype)
d_key_shape = jax.ShapeDtypeStruct(k.shape, k.dtype)
d_value_shape = jax.ShapeDtypeStruct(v.shape, v.dtype)
seed = jnp.array([1])

# Call the NKI kernel
if os.environ.get('ENABLE_NEW_UNSHARDED_ATTN_KERNEL'):
from neuronxcc.nki.kernels.attention import flash_attn_bwd
import neuronxcc.nki.language as nl
assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f'unexpected num_heads: {num_heads}'
d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(2) * (num_heads//2)](q, k, v, o, dy, lse, seed, use_causal_mask=causal, mixed_precision=True, dropout_p=0.0, softmax_scale=softmax_scale)
else:
from neuronxcc.nki._private_kernels.legacy.attention import flash_attn_bwd
from neuronxcc.nki._private_kernels.attention import flash_attn_bwd_shardable
from neuronxcc.starfish.penguin.targets.nki.private_api import vnc
d_query, d_key, d_value = nki_call(
partial(flash_attn_bwd_shardable if use_lnc else flash_attn_bwd, use_causal_mask=causal, mixed_precision=True, dropout_p=0.0, softmax_scale=softmax_scale),
q, k, v, o, dy, lse, seed,
out_shape=[d_query_shape, d_key_shape, d_value_shape],
grid=(batch_size, num_heads, vnc(2)) if use_lnc else (batch_size, num_heads)
)

# Batch seq_len heads, head_dim
# Transpose the gradients back to the original shape
d_query = d_query.transpose(0, 3, 1, 2)
d_key = d_key.transpose(0, 3, 1, 2)
d_value = d_value.transpose(0, 3, 1, 2)

return d_query, d_key, d_value
lse, o, q, k, v, bias = res
batch_size, num_heads, d_model, seq_len = q.shape
_, kv_seq_len, _, _ = k.shape

# Transpose the input tensors
o = o.transpose(0, 2, 3, 1)
dy = d_attn_output.transpose(0, 2, 3, 1)

# Transpose v tensor
v = jnp.transpose(v, axes=(0, 1, 3, 2))
# Create the output buffer shapes
d_query_shape = jax.ShapeDtypeStruct(q.shape, q.dtype)
d_key_shape = jax.ShapeDtypeStruct(k.shape, k.dtype)
d_value_shape = jax.ShapeDtypeStruct(v.shape, v.dtype)
seed = jnp.array([1])

from neuronxcc.nki.kernels.attention import flash_attn_bwd
import neuronxcc.nki.language as nl

# Call the NKI kernel
if bias != None:
assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f"unexpected num_heads: {num_heads}"
d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(2) * (num_heads // 2)](
q,
k,
v,
o,
dy,
lse,
seed,
bias,
use_causal_mask=causal,
mixed_precision=True,
dropout_p=0.0,
softmax_scale=softmax_scale,
)
else:
d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(2) * (num_heads // 2)](
q,
k,
v,
o,
dy,
lse,
seed,
use_causal_mask=causal,
mixed_precision=True,
dropout_p=0.0,
softmax_scale=softmax_scale,
)

# Batch seq_len heads, head_dim
# Transpose the gradients back to the original shape
d_query = d_query.transpose(0, 3, 1, 2)
d_key = d_key.transpose(0, 3, 1, 2)
d_value = d_value.transpose(0, 3, 1, 2)

return d_query, d_key, d_value, None

flash_attention.defvjp(_mha_forward, _mha_backward)
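For orientation, a minimal usage sketch of the reworked wrapper follows; the shapes, dtypes, random inputs, and all-zeros bias are illustrative assumptions, and it presumes a Neuron device with the LNC environment variable set, as the module requires above.

# Hypothetical call into the bias-aware flash_attention wrapper above (Neuron device assumed).
import jax
import jax.numpy as jnp

from axlearn.common.flash_attention.neuron_attention import flash_attention

batch, q_len, kv_len, num_heads, d_model = 1, 2048, 2048, 8, 128
kq, kk, kvv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, q_len, num_heads, d_model), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (batch, kv_len, num_heads, d_model), dtype=jnp.bfloat16)
v = jax.random.normal(kvv, (batch, kv_len, num_heads, d_model), dtype=jnp.bfloat16)
# 4-D additive logit bias [batch, num_heads, q_seq_len, kv_seq_len]; zeros mean no extra masking.
bias = jnp.zeros((batch, num_heads, q_len, kv_len), dtype=jnp.bfloat16)

out = flash_attention(q, k, v, bias, True, d_model**-0.5)  # causal=True, softmax_scale=1/sqrt(d_model)
# The custom_vjp registration above wires in the backward kernel, so gradients flow through the call.
dq = jax.grad(lambda q_: flash_attention(q_, k, v, bias, True, d_model**-0.5).astype(jnp.float32).sum())(q)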
2 changes: 1 addition & 1 deletion axlearn/common/flash_attention/utils.py
@@ -171,7 +171,7 @@ def jit_attn(query, key, value, bias, segment_ids):
if segment_ids != None:
raise Exception("Sequence Packing is not supported on Neuron backend")
return neuron_flash_attention(
query, key, value, causal, softmax_scale)
query, key, value, bias, causal, softmax_scale)

return jit_attn
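Because this dispatch now forwards bias straight to the Neuron kernel, callers are expected to supply a 4-D additive logit bias, per the NOTE in the kernel diff above. A small sketch of deriving one from a boolean attention mask; the helper name and the finfo-min convention are assumptions for illustration, not axlearn API.

import jax.numpy as jnp

def mask_to_bias(attend_mask, dtype=jnp.float32):
    # attend_mask: boolean [batch, num_heads, q_seq_len, kv_seq_len], True where attention is allowed.
    # Returns an additive logit bias of the same shape: 0 where allowed, a large negative value elsewhere.
    neg_inf = jnp.asarray(jnp.finfo(dtype).min, dtype=dtype)
    return jnp.where(attend_mask, jnp.zeros((), dtype=dtype), neg_inf)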

7 changes: 6 additions & 1 deletion axlearn/common/gda_test.py
@@ -28,7 +28,7 @@ class GDATest(TestCase):
itertools.product(
((1, 1), (8, 1), (4, 2)), # mesh_shape
(1, 16), # per_host_batch_size
(DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition
(DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH), # data_partition
)
)
def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition):
@@ -43,11 +43,16 @@ def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition
devices = mesh_utils.create_device_mesh(mesh_shape)
if data_partition == DataPartitionType.FULL:
global_batch_size = per_host_batch_size * jax.process_count()
elif data_partition == DataPartitionType.BATCH:
global_batch_size = per_host_batch_size * jax.process_count()
else:
assert data_partition == DataPartitionType.REPLICATED
global_batch_size = per_host_batch_size
if data_partition == DataPartitionType.FULL and global_batch_size < jax.device_count():
return
# first axis is assumed to be batch axis
if data_partition == DataPartitionType.BATCH and global_batch_size % mesh_shape[0] == 0:
return
per_host_input_batch = dict(x=jnp.zeros((per_host_batch_size, 8), dtype=jnp.float32))
with jax.sharding.Mesh(devices, ("data", "model")):
global_input_batch = host_to_global_device_array(
2 changes: 1 addition & 1 deletion axlearn/common/utils.py
@@ -646,7 +646,7 @@ def make_gda(x, partition_spec):
elif partition == DataPartitionType.REPLICATED:
global_shape = (x.shape[0], *x.shape[1:])
elif partition == DataPartitionType.BATCH:
global_shape = (x.shape[0], *x.shape[1:])
global_shape = (x.shape[0] * process_count, *x.shape[1:])
else:
raise NotImplementedError(f"Unsupported partition: {partition}")
return jax.make_array_from_process_local_data(
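The make_gda change above scales the leading (batch) axis by the process count when data is partitioned by batch; a short sketch of the resulting global shapes, with the host count and local shape chosen only for illustration.

# Global-shape arithmetic in make_gda (illustrative values).
process_count = 4                      # stand-in for the actual process count on a 4-host setup
local_shape = (16, 8)                  # per-host array, batch axis first
batch_global = (local_shape[0] * process_count, *local_shape[1:])   # (64, 8) under DataPartitionType.BATCH
replicated_global = (local_shape[0], *local_shape[1:])              # (16, 8) under DataPartitionType.REPLICATED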
11 changes: 4 additions & 7 deletions axlearn/experiments/text/gpt/fuji.py
@@ -125,8 +125,7 @@ def get_trainer_kwargs(
return {}
max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
max_sequence_length = MAX_SEQUENCE_LENGTH[version]
# train_batch_size = tokens_per_batch // max_sequence_length
train_batch_size = 16
train_batch_size = tokens_per_batch // max_sequence_length

# Whether to use grouped query attention.
num_kv_heads = None
Expand All @@ -153,7 +152,6 @@ def get_trainer_kwargs(
rope_theta=rope_theta,
shared_lm_head=True,
flash_attention=flash_attention,
stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
),
learner_kwargs=dict(peak_lr=6e-4, weight_decay=0.01),
max_sequence_length=64,
@@ -210,7 +208,6 @@ def get_trainer_kwargs(
rope_theta=rope_theta,
shared_lm_head=True,
flash_attention=flash_attention,
stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
),
learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
@@ -393,18 +390,17 @@ def get_trainer_kwargs(
hidden_dim=128 * 64,
num_heads=64,
# No GQA support in V1 models, so num_kv_heads is the same as num_heads.
num_kv_heads=None,# if version == Version.V1 else 8,
num_kv_heads=None if version == Version.V1 else 8,
# TODO(kelvin-zou): Remove the perf numbers for V5e (OOM).
ffn_dim=scaled_hidden_dim(scale=3.5, round_up_to_multiples_of=256),
rope_theta=rope_theta,
# shared_lm_head=False,
shared_lm_head=True,
flash_attention=True,
stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
),
learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
train_batch_size=int((jax.device_count()/4)),
input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(fsdp=-1),
@@ -447,6 +443,7 @@ def get_trainer_kwargs(
raise NotImplementedError(f"Unknown model size {model_size}.")
model_kwargs = trainer_kwargs.pop("model_kwargs")
model_kwargs.setdefault("vocab_size", vocab_size)
model_kwargs.setdefault("stack_cfg", None if backend != "neuron" else StackedTransformerLayer.default_config())
trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
max_step=trainer_kwargs["max_step"],
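To make the GQA enablement and the batch-size changes concrete, a short arithmetic sketch; the head counts come from the 64-head config in the diff, while the token budget, sequence length, and device count are assumed values.

# Grouped-query attention head accounting for the 64-head config above.
num_heads = 64
num_kv_heads = 8                                   # None for V1 (no GQA), 8 otherwise
assert num_heads % num_kv_heads == 0
q_heads_per_kv_head = num_heads // num_kv_heads    # 8 query heads share each KV head

# Batch-size arithmetic used in this file (values assumed for illustration).
tokens_per_batch, max_sequence_length = 4 * 1024 * 1024, 4096
train_batch_size = tokens_per_batch // max_sequence_length   # 1024 in this example
neuron_train_batch_size = int(64 / 4)                         # int(jax.device_count() / 4) with 64 devices assumed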
