flash attention & input data sharding test
apoorvtintin committed Dec 10, 2024
1 parent 8ee79cc commit 8de3b9c
Showing 2 changed files with 156 additions and 0 deletions.
130 changes: 130 additions & 0 deletions axlearn/common/flash_attention/neuron_attention_test.py
@@ -0,0 +1,130 @@
# Copyright © 2024 Amazon Inc.
"""Tests for Flash attention on Neuron. Tested on trn1."""
import functools

import chex
import jax
import jax.numpy as jnp
import pytest

from axlearn.common.flash_attention.neuron_attention import flash_attention
from axlearn.common.flash_attention.utils import mha_reference


@pytest.mark.parametrize(
    "batch_size,seq_len,num_heads,per_head_dim",
    [
        (1, 2048, 1, 64),
        (2, 2048, 2, 64),
        (1, 2048, 1, 128),
        (2, 2048, 2, 128),
        (1, 2048, 8, 128),
        (2, 2048, 8, 128),
    ],
)
@pytest.mark.parametrize("use_fwd", [True, False])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
@pytest.mark.skipif(jax.devices()[0].platform != "neuron", reason="Test only runs on Neuron.")
def test_fwd_against_ref(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    per_head_dim: int,
    use_fwd: bool,
    causal: bool,
    input_dtype: jnp.dtype,
):
    sm_scale = 1.0 / (per_head_dim**0.5)
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

    bias = None
    segment_ids = None

    if use_fwd:

        @jax.jit
        def impl(q, k, v, bias):
            fn = functools.partial(
                flash_attention,
                causal=causal,
                softmax_scale=sm_scale,
            )
            out, _ = jax.vjp(fn, q, k, v, bias)
            return out

    else:
        impl = functools.partial(
            flash_attention,
            causal=causal,
            softmax_scale=sm_scale,
        )

    o = impl(q, k, v, bias)
    o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale)
    chex.assert_trees_all_close(o, o_ref, atol=0.05)


@pytest.mark.parametrize(
    "batch_size,num_heads,seq_len,per_head_dim",
    [
        (1, 1, 2048, 64),
        (2, 2, 2048, 64),
        (1, 1, 2048, 128),
        (2, 2, 2048, 128),
        (1, 8, 2048, 128),
        (2, 8, 2048, 128),
    ],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
@pytest.mark.skipif(jax.devices()[0].platform != "neuron", reason="Test only runs on Neuron.")
def test_bwd_against_ref(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    per_head_dim: int,
    causal: bool,
    input_dtype: jnp.dtype,
):
    sm_scale = 1.0 / (per_head_dim**0.5)
    q = jax.random.normal(
        jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype
    )
    k = jax.random.normal(
        jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype
    )
    v = jax.random.normal(
        jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype
    )

    bias = None
    segment_ids = None

    def fn(q, k, v, bias):
        return flash_attention(
            q,
            k,
            v,
            bias,
            causal=causal,
            softmax_scale=sm_scale,
        ).sum()

    def ref_fn(q, k, v, bias, segment_ids):
        return mha_reference(
            q,
            k,
            v,
            bias,
            segment_ids,
            causal=causal,
            softmax_scale=sm_scale,
        ).sum()

    jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias)
    jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids)
    chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07)
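
For reference, a minimal standalone sketch of the same forward comparison, runnable outside the pytest harness on a Neuron (trn1) host; the shapes, dtype, call signatures, and the 0.05 tolerance are taken from the test above, and the sketch itself is illustrative rather than part of this diff:

import jax
import jax.numpy as jnp

from axlearn.common.flash_attention.neuron_attention import flash_attention
from axlearn.common.flash_attention.utils import mha_reference

# Shapes mirror one of the parametrized cases in the test above.
batch_size, seq_len, num_heads, per_head_dim = 2, 2048, 2, 64
sm_scale = 1.0 / (per_head_dim**0.5)
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)

# Same call signatures as in the test; bias and segment_ids are None here.
out = flash_attention(q, k, v, None, causal=True, softmax_scale=sm_scale)
ref = mha_reference(q, k, v, None, None, causal=True, softmax_scale=sm_scale)
print(jnp.max(jnp.abs(out - ref)))  # Expected to stay within the 0.05 tolerance used above.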
26 changes: 26 additions & 0 deletions axlearn/common/utils_test.py
@@ -42,6 +42,7 @@
)
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.utils import (
    DataPartitionType,
    PHYSICAL_TO_LOGICAL_DISPATCH_KEY,
    HybridMeshShape,
    MeshShape,
@@ -1701,6 +1702,31 @@ def test_length(self):
class HostToGlobalArrayTest(TestCase):
    """Tests host_to_global_device_array."""

    @pytest.mark.neuron
    def test_partition_batch(self):
"""Test a case where each process produces a slice."""
        device_count = jax.device_count()
        process_count = jax.process_count()
        print(f"{device_count=}, {process_count=}")
        assert device_count > 1

        global_shape = (device_count // 2, 1)
        assert global_shape[0] % process_count == 0
        per_feed_size = global_shape[0] // process_count
        feed_index = jax.process_index()

        with jax.sharding.Mesh(np.array(jax.devices()).reshape(device_count // 2, 2), ("x", "y")):
            start = feed_index * per_feed_size
            local_x = jnp.arange(start, start + per_feed_size)[:, None]

            # Construct global array.
            global_x = host_to_global_device_array(
                local_x, partition=DataPartitionType.BATCH, batch_axis_names="x"
            )

            # Compare against expected.
            expected = jnp.arange(global_shape[0])[:, None]
            self.assertEqual(jnp.mean(expected), jnp.mean(global_x))
            self.assertNestedEqual(expected, replicate_to_local_data(global_x))

    @pytest.mark.tpu
    def test_partition_full(self):
        """Test a case where each process produces a slice."""
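
The new test_partition_batch case above exercises host_to_global_device_array with DataPartitionType.BATCH, where each process contributes its local slice of the batch along the "x" mesh axis. Below is a minimal single-host sketch of that path; it assumes the same axlearn helpers used in the test, and the (N, 1) mesh shape is illustrative:

import jax
import jax.numpy as jnp
import numpy as np

from axlearn.common.utils import (
    DataPartitionType,
    host_to_global_device_array,
    replicate_to_local_data,
)

# Lay out all local devices along the "x" (batch) axis of a 2D mesh.
devices = np.array(jax.devices()).reshape(-1, 1)
with jax.sharding.Mesh(devices, ("x", "y")):
    # Each process contributes its local slice of the batch; on a single host
    # this is the whole batch.
    local_x = jnp.arange(devices.shape[0])[:, None]
    global_x = host_to_global_device_array(
        local_x, partition=DataPartitionType.BATCH, batch_axis_names="x"
    )
    # Gather the global array back to host memory for inspection.
    print(replicate_to_local_data(global_x))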
