fix regressions and messages
lipovsek-aws committed Dec 11, 2024
1 parent 3c03930 · commit 80967cc
Showing 3 changed files with 6 additions and 4 deletions.
axlearn/common/flash_attention/neuron_attention_test.py (6 changes: 4 additions & 2 deletions)

@@ -11,6 +11,10 @@
from axlearn.common.flash_attention.utils import mha_reference


+if jax.default_backend() != "neuron":
+    pytestmark = pytest.mark.skip(reason="Incompatible hardware, Neuron only test.")
+
+
@pytest.mark.parametrize(
"batch_size,seq_len,num_heads,per_head_dim",
[
@@ -25,7 +29,6 @@
@pytest.mark.parametrize("use_fwd", [True, False])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
-@pytest.mark.skipif(jax.devices()[0].platform != "neuron", reason="Test only runs on Neuron.")
def test_fwd_against_ref(
batch_size: int,
seq_len: int,
@@ -81,7 +84,6 @@ def impl(q, k, v, bias):
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
-@pytest.mark.skipif(jax.devices()[0].platform != "neuron", reason="Test only runs on Neuron.")
def test_bwd_against_ref(
batch_size: int,
num_heads: int,
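The net effect in this file: the Neuron gate moves from a per-test @pytest.mark.skipif decorator to a single module-level pytestmark assignment, which pytest applies to every test collected from the module. A minimal standalone sketch of the pattern (hypothetical test name; the condition and reason mirror the diff above):

import jax
import pytest

# pytest treats a module-level `pytestmark` as a mark on every test in
# the module, so one assignment replaces a decorator on each function.
if jax.default_backend() != "neuron":
    pytestmark = pytest.mark.skip(reason="Incompatible hardware, Neuron only test.")


def test_neuron_only_example():
    # Executes only when JAX selected the Neuron backend.
    assert jax.default_backend() == "neuron"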
axlearn/common/host_array_test.py (1 change: 1 addition & 0 deletions)

@@ -12,6 +12,7 @@
from absl.testing import absltest, parameterized
from jax import numpy as jnp
from jax.experimental import mesh_utils
+import pytest

from axlearn.common.test_utils import TestCase, is_supported_mesh_shape, is_supported_platform
from axlearn.common.utils import (
axlearn/common/inference_test.py (3 changes: 1 addition & 2 deletions)

@@ -20,6 +20,7 @@
from absl.testing import absltest, parameterized
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit
+import pytest

from axlearn.common import layers, test_utils, utils
from axlearn.common.base_model import BaseModel
@@ -168,11 +169,9 @@ def predict_batch(self, input_batch: NestedTensor) -> NestedTensor:
def is_supported(
platform: str,
mesh_shape: tuple[int, int],
param_dtype: jnp.dtype,
inference_dtype: Optional[jnp.dtype],
global_batch_size: int,
data_partition: DataPartitionType,
use_ema: bool = False,
):
del param_dtype, use_ema # not used
# TODO(xuan-zou): jax 0.4.25 breaks bfloat16 on CPU due to high variance on
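The import pytest additions in host_array_test.py and inference_test.py suggest that helpers such as is_supported now skip unsupported configurations imperatively with pytest.skip (the TODO about bfloat16 on CPU points the same way). A minimal sketch of imperative skipping, with a hypothetical test and condition rather than the repository's exact code:

import jax
import pytest
from jax import numpy as jnp


def test_bfloat16_example():
    # pytest.skip() inside the test body skips at run time, so the
    # condition can use values computed while the test is running.
    if jax.default_backend() == "cpu":
        pytest.skip("bfloat16 on CPU is skipped in this sketch.")
    x = jnp.ones((4, 4), dtype=jnp.bfloat16)
    assert x.dtype == jnp.bfloat16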
