diff --git a/axlearn/common/flash_attention/neuron_attention.py b/axlearn/common/flash_attention/neuron_attention.py
index 65d68531d..286754ed2 100644
--- a/axlearn/common/flash_attention/neuron_attention.py
+++ b/axlearn/common/flash_attention/neuron_attention.py
@@ -10,111 +10,154 @@
 from jax_neuronx import nki_call
 from neuronxcc.nki._private_kernels.legacy.attention import flash_attn_bwd, flash_fwd
 
-if 'LNC' not in os.environ:
+# enable buffer donation in neuron
+jax._src.interpreters.mlir._platforms_with_donation.append("neuron")
+
+if "LNC" not in os.environ:
     raise ValueError("LNC environment variable is not set")
 
-cores_per_lnc = os.environ['LNC']
-if cores_per_lnc == '2':
+cores_per_lnc = os.environ["LNC"]
+if cores_per_lnc == "2":
     use_lnc = True
-elif cores_per_lnc == '1':
+elif cores_per_lnc == "1":
     use_lnc = False
 else:
     raise ValueError("LNC environment variable must be set to '1' or '2'")
 
-disable_sharded_attn_kernel = os.environ.get('DISABLE_SHARDED_ATTN_KERNEL')
-if disable_sharded_attn_kernel is not None:
-    use_lnc = False
-
 if use_lnc:
-    from neuronxcc.nki._private_kernels.attention import flash_fwd_shardable, flash_attn_bwd_shardable
+    from neuronxcc.nki._private_kernels.attention import (
+        flash_fwd_shardable,
+        flash_attn_bwd_shardable,
+    )
     from neuronxcc.starfish.penguin.targets.nki.private_api import vnc
 
 from jax import custom_vjp
 
-@partial(custom_vjp, nondiff_argnums=(3,4))
-def flash_attention(query, key, value, causal, softmax_scale):
-    out, _ = _mha_forward(query, key, value, causal, softmax_scale)
-    return out
-
-def _mha_forward(query, key, value, causal, softmax_scale):
-    # Get the batch size, sequence lengths, number of heads, and hidden dimension
-    batch_size, q_seq_len, num_heads, d_model = query.shape
-    _, kv_seq_len, _, _ = key.shape
-
-    # Transpose the query, key, and value tensors
-    q = query.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, q_seq_len]
-    k = key.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, kv_seq_len]
-    v = value.transpose(0, 2, 1, 3)  # [batch_size, num_heads, kv_seq_len, d_model]
-
-    # Create the output buffer
-    import neuronxcc.nki.language as nl
-    attn_output_shape = jax.ShapeDtypeStruct((batch_size, num_heads, q_seq_len, d_model), dtype=query.dtype)
-    lse_shape = jax.ShapeDtypeStruct((batch_size, num_heads, nl.tile_size.pmax, q_seq_len // nl.tile_size.pmax), dtype=jnp.float32)
-    seed = jnp.array([1])
-
-    # Call the NKI kernel
-    if os.environ.get('ENABLE_NEW_UNSHARDED_ATTN_KERNEL'):
-        from neuronxcc.nki.kernels.attention import flash_attn_bwd, flash_fwd
-        import neuronxcc.nki.language as nl
-
-        assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f'unexpect num_heads: {num_heads}'
-        attn_output, lse = flash_fwd[batch_size, nl.nc(2) * (num_heads//2)](q, k, v, seed, use_causal_mask=causal, softmax_scale=softmax_scale, mixed_precision=True, dropout_p=0.0)
-    else:
-        from neuronxcc.nki._private_kernels.legacy.attention import flash_fwd
-        from neuronxcc.nki._private_kernels.attention import flash_fwd_shardable
-        from neuronxcc.starfish.penguin.targets.nki.private_api import vnc
-        attn_output, lse = nki_call(
-            partial(flash_fwd_shardable if use_lnc else flash_fwd, use_causal_mask=causal, softmax_scale=softmax_scale, mixed_precision=True, dropout_p=0.0),
-            q, k, v, seed,
-            out_shape=(attn_output_shape, lse_shape),
-            grid=(batch_size, num_heads, vnc(2)) if use_lnc else (batch_size, num_heads)
-        )
-    # Transpose the output back to the original shape
-    attn_output = attn_output.transpose(0, 2, 1, 3)  # [batch_size, q_seq_len, num_heads, d_model]
-
-    return attn_output, (lse, attn_output, q, k, v)
+
+@partial(custom_vjp, nondiff_argnums=(4, 5))
+def flash_attention(query, key, value, bias, causal, softmax_scale):
+    # NOTE: Merge with upstream. Old code supports both 2d and 4d bias but upstream only supports 4d.
+    # We no longer need 2d logit_bias but should sync how we merge this check with upstream.
+    # assert bias.ndim == 4, f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}"
+    out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale)
+    return out
+
+
+def _mha_forward(query, key, value, bias, causal, softmax_scale):
+    # Get the batch size, sequence lengths, number of heads, and hidden dimension
+    batch_size, q_seq_len, num_heads, d_model = query.shape
+    _, kv_seq_len, _, _ = key.shape
+
+    # Transpose the query, key, and value tensors
+    q = query.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, q_seq_len]
+    k = key.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, kv_seq_len]
+    v = value.transpose(0, 2, 1, 3)  # [batch_size, num_heads, kv_seq_len, d_model]
+
+    import neuronxcc.nki.language as nl
+    from neuronxcc.nki.kernels.attention import flash_attn_bwd, flash_fwd
+
+    # Create the output buffer
+    attn_output_shape = jax.ShapeDtypeStruct(
+        (batch_size, num_heads, q_seq_len, d_model), dtype=query.dtype
+    )
+    lse_shape = jax.ShapeDtypeStruct(
+        (batch_size, num_heads, nl.tile_size.pmax, q_seq_len // nl.tile_size.pmax),
+        dtype=jnp.float32,
+    )
+    seed = jnp.array([1])
+
+    # Call the NKI kernel
+
+    assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f"unexpected num_heads: {num_heads}"
+
+    if bias is not None:
+        attn_output, lse = flash_fwd[batch_size, nl.nc(2) * (num_heads // 2)](
+            q,
+            k,
+            v,
+            seed,
+            bias,
+            use_causal_mask=causal,
+            softmax_scale=softmax_scale,
+            mixed_precision=True,
+            dropout_p=0.0,
+        )
+    else:
+        attn_output, lse = flash_fwd[batch_size, nl.nc(2) * (num_heads // 2)](
+            q,
+            k,
+            v,
+            seed,
+            use_causal_mask=causal,
+            softmax_scale=softmax_scale,
+            mixed_precision=True,
+            dropout_p=0.0,
+        )
+    # Transpose the output back to the original shape
+    attn_output = attn_output.transpose(0, 2, 1, 3)  # [batch_size, q_seq_len, num_heads, d_model]
+
+    return attn_output, (lse, attn_output, q, k, v, bias)
+
 
 def _mha_backward(causal, softmax_scale, res, d_attn_output):
-    lse, o, q, k, v = res
-    batch_size, num_heads, d_model, seq_len = q.shape
-    _, kv_seq_len, _, _ = k.shape
-
-    # Transpose the input tensors
-    o = o.transpose(0, 2, 3, 1)
-    dy = d_attn_output.transpose(0, 2, 3, 1)
-
-    # Transpose v tensor
-    v = jnp.transpose(v, axes=(0, 1, 3, 2))
-    # Create the output buffer shapes
-    d_query_shape = jax.ShapeDtypeStruct(q.shape, q.dtype)
-    d_key_shape = jax.ShapeDtypeStruct(k.shape, k.dtype)
-    d_value_shape = jax.ShapeDtypeStruct(v.shape, v.dtype)
-    seed = jnp.array([1])
-
-    # Call the NKI kernel
-    if os.environ.get('ENABLE_NEW_UNSHARDED_ATTN_KERNEL'):
-        from neuronxcc.nki.kernels.attention import flash_attn_bwd
-        import neuronxcc.nki.language as nl
-        assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f'unexpected num_heads: {num_heads}'
-        d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(2) * (num_heads//2)](q, k, v, o, dy, lse, seed, use_causal_mask=causal, mixed_precision=True, dropout_p=0.0, softmax_scale=softmax_scale)
-    else:
-        from neuronxcc.nki._private_kernels.legacy.attention import flash_attn_bwd
-        from neuronxcc.nki._private_kernels.attention import flash_attn_bwd_shardable
-        from neuronxcc.starfish.penguin.targets.nki.private_api import vnc
-        d_query, d_key, d_value = nki_call(
-            partial(flash_attn_bwd_shardable if use_lnc else flash_attn_bwd, use_causal_mask=causal, mixed_precision=True, dropout_p=0.0, softmax_scale=softmax_scale),
-            q, k, v, o, dy, lse, seed,
-            out_shape=[d_query_shape, d_key_shape, d_value_shape],
-            grid=(batch_size, num_heads, vnc(2)) if use_lnc else (batch_size, num_heads)
-        )
-
-    # Batch seq_len heads, head_dim
-    # Transpose the gradients back to the original shape
-    d_query = d_query.transpose(0, 3, 1, 2)
-    d_key = d_key.transpose(0, 3, 1, 2)
-    d_value = d_value.transpose(0, 3, 1, 2)
-
-    return d_query, d_key, d_value
+    lse, o, q, k, v, bias = res
+    batch_size, num_heads, d_model, seq_len = q.shape
+    _, kv_seq_len, _, _ = k.shape
+
+    # Transpose the input tensors
+    o = o.transpose(0, 2, 3, 1)
+    dy = d_attn_output.transpose(0, 2, 3, 1)
+
+    # Transpose v tensor
+    v = jnp.transpose(v, axes=(0, 1, 3, 2))
+    # Create the output buffer shapes
+    d_query_shape = jax.ShapeDtypeStruct(q.shape, q.dtype)
+    d_key_shape = jax.ShapeDtypeStruct(k.shape, k.dtype)
+    d_value_shape = jax.ShapeDtypeStruct(v.shape, v.dtype)
+    seed = jnp.array([1])
+
+    from neuronxcc.nki.kernels.attention import flash_attn_bwd
+    import neuronxcc.nki.language as nl
+
+    # Call the NKI kernel
+    if bias is not None:
+        assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f"unexpected num_heads: {num_heads}"
+        d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(2) * (num_heads // 2)](
+            q,
+            k,
+            v,
+            o,
+            dy,
+            lse,
+            seed,
+            bias,
+            use_causal_mask=causal,
+            mixed_precision=True,
+            dropout_p=0.0,
+            softmax_scale=softmax_scale,
+        )
+    else:
+        d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(2) * (num_heads // 2)](
+            q,
+            k,
+            v,
+            o,
+            dy,
+            lse,
+            seed,
+            use_causal_mask=causal,
+            mixed_precision=True,
+            dropout_p=0.0,
+            softmax_scale=softmax_scale,
+        )
+
+    # Batch seq_len heads, head_dim
+    # Transpose the gradients back to the original shape
+    d_query = d_query.transpose(0, 3, 1, 2)
+    d_key = d_key.transpose(0, 3, 1, 2)
+    d_value = d_value.transpose(0, 3, 1, 2)
+
+    return d_query, d_key, d_value, None
 
-flash_attention.defvjp(_mha_forward, _mha_backward)
+flash_attention.defvjp(_mha_forward, _mha_backward)
diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py
index 37baab3b1..785ceadb5 100644
--- a/axlearn/common/flash_attention/utils.py
+++ b/axlearn/common/flash_attention/utils.py
@@ -171,7 +171,7 @@ def jit_attn(query, key, value, bias, segment_ids):
             if segment_ids != None:
                 raise Exception("Sequence Packing is not supported on Neuron backend")
             return neuron_flash_attention(
-                query, key, value, causal, softmax_scale)
+                query, key, value, bias, causal, softmax_scale)
 
         return jit_attn
 
diff --git a/axlearn/common/gda_test.py b/axlearn/common/gda_test.py
index edf415517..03b820384 100644
--- a/axlearn/common/gda_test.py
+++ b/axlearn/common/gda_test.py
@@ -28,7 +28,7 @@ class GDATest(TestCase):
         itertools.product(
             ((1, 1), (8, 1), (4, 2)),  # mesh_shape
             (1, 16),  # per_host_batch_size
-            (DataPartitionType.FULL, DataPartitionType.REPLICATED),  # data_partition
+            (DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH),  # data_partition
         )
     )
     def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition):
@@ -43,11 +43,16 @@ def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition
         devices = mesh_utils.create_device_mesh(mesh_shape)
         if data_partition == DataPartitionType.FULL:
            global_batch_size = per_host_batch_size * jax.process_count()
+        elif data_partition == DataPartitionType.BATCH:
+            global_batch_size = per_host_batch_size * jax.process_count()
         else:
             assert data_partition == DataPartitionType.REPLICATED
             global_batch_size = per_host_batch_size
         if data_partition == DataPartitionType.FULL and global_batch_size < jax.device_count():
             return
+        # The first axis is assumed to be the batch axis.
+        if data_partition == DataPartitionType.BATCH and global_batch_size % mesh_shape[0] != 0:
+            return
         per_host_input_batch = dict(x=jnp.zeros((per_host_batch_size, 8), dtype=jnp.float32))
         with jax.sharding.Mesh(devices, ("data", "model")):
             global_input_batch = host_to_global_device_array(
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 952ea2653..ec1e6e207 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -646,7 +646,7 @@ def make_gda(x, partition_spec):
         elif partition == DataPartitionType.REPLICATED:
             global_shape = (x.shape[0], *x.shape[1:])
         elif partition == DataPartitionType.BATCH:
-            global_shape = (x.shape[0], *x.shape[1:])
+            global_shape = (x.shape[0] * process_count, *x.shape[1:])
         else:
             raise NotImplementedError(f"Unsupported partition: {partition}")
         return jax.make_array_from_process_local_data(
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index eb4c921ba..f0b5a3a0b 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -125,8 +125,7 @@ def get_trainer_kwargs(
         return {}
     max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
     max_sequence_length = MAX_SEQUENCE_LENGTH[version]
-    # train_batch_size = tokens_per_batch // max_sequence_length
-    train_batch_size = 16
+    train_batch_size = tokens_per_batch // max_sequence_length
 
     # Whether to use grouped query attention.
     num_kv_heads = None
@@ -153,7 +152,6 @@
                 rope_theta=rope_theta,
                 shared_lm_head=True,
                 flash_attention=flash_attention,
-                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=6e-4, weight_decay=0.01),
             max_sequence_length=64,
@@ -210,7 +208,6 @@
                 rope_theta=rope_theta,
                 shared_lm_head=True,
                 flash_attention=flash_attention,
-                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
@@ -393,18 +390,17 @@
                 hidden_dim=128 * 64,
                 num_heads=64,
                 # No GQA support in V1 models, so num_kv_heads is the same as num_heads.
-                num_kv_heads=None,# if version == Version.V1 else 8,
+                num_kv_heads=None if version == Version.V1 else 8,
                 # TODO(kelvin-zou): Remove the perf numbers for V5e (OOM).
                ffn_dim=scaled_hidden_dim(scale=3.5, round_up_to_multiples_of=256),
                 rope_theta=rope_theta,
                 # shared_lm_head=False,
                 shared_lm_head=True,
                 flash_attention=True,
-                stack_cfg=None if backend != "neuron" else StackedTransformerLayer.default_config(),
             ),
             learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
-            train_batch_size=train_batch_size,
+            train_batch_size=int(jax.device_count() / 4),
             input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(fsdp=-1),
@@ -447,6 +443,7 @@
             raise NotImplementedError(f"Unknown model size {model_size}.")
     model_kwargs = trainer_kwargs.pop("model_kwargs")
     model_kwargs.setdefault("vocab_size", vocab_size)
+    model_kwargs.setdefault("stack_cfg", None if backend != "neuron" else StackedTransformerLayer.default_config())
     trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
     trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
         max_step=trainer_kwargs["max_step"],