diff --git a/axlearn/cloud/gcp/tpu_health_check.py b/axlearn/cloud/gcp/tpu_health_check.py
index f396e3da..06750392 100644
--- a/axlearn/cloud/gcp/tpu_health_check.py
+++ b/axlearn/cloud/gcp/tpu_health_check.py
@@ -31,6 +31,7 @@ from contextlib import contextmanager
 from datetime import datetime
 from typing import Literal, Optional, Union
 
+import pytest
 import tensorflow as tf
 from absl import flags, logging
 
diff --git a/axlearn/common/flash_attention/neuron_attention_test.py b/axlearn/common/flash_attention/neuron_attention_test.py
index 2aa2251f..370df366 100644
--- a/axlearn/common/flash_attention/neuron_attention_test.py
+++ b/axlearn/common/flash_attention/neuron_attention_test.py
@@ -11,6 +11,10 @@
 from axlearn.common.flash_attention.utils import mha_reference
 
 
+if jax.default_backend() != "neuron":
+    pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.")
+
+
 @pytest.mark.parametrize(
     "batch_size,seq_len,num_heads,per_head_dim",
     [
@@ -25,7 +29,6 @@
 @pytest.mark.parametrize("use_fwd", [True, False])
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
-@pytest.mark.skipif(jax.devices()[0].platform != "neuron", reason="Test only runs on Neuron.")
 def test_fwd_against_ref(
     batch_size: int,
     seq_len: int,
@@ -81,7 +84,6 @@ def impl(q, k, v, bias):
 )
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
-@pytest.mark.skipif(jax.devices()[0].platform != "neuron", reason="Test only runs on Neuron.")
 def test_bwd_against_ref(
     batch_size: int,
     num_heads: int,
diff --git a/axlearn/common/host_array_test.py b/axlearn/common/host_array_test.py
index a0ec8467..52ea9f39 100644
--- a/axlearn/common/host_array_test.py
+++ b/axlearn/common/host_array_test.py
@@ -12,6 +12,7 @@
 from absl.testing import absltest, parameterized
 from jax import numpy as jnp
 from jax.experimental import mesh_utils
+import pytest
 
 from axlearn.common.test_utils import TestCase, is_supported_mesh_shape, is_supported_platform
 from axlearn.common.utils import (
diff --git a/axlearn/common/inference_test.py b/axlearn/common/inference_test.py
index 3fc4e517..6b396337 100644
--- a/axlearn/common/inference_test.py
+++ b/axlearn/common/inference_test.py
@@ -20,6 +20,7 @@
 from absl.testing import absltest, parameterized
 from jax.experimental import mesh_utils
 from jax.experimental.pjit import pjit
+import pytest
 
 from axlearn.common import layers, test_utils, utils
 from axlearn.common.base_model import BaseModel
@@ -168,13 +169,10 @@ def predict_batch(self, input_batch: NestedTensor) -> NestedTensor:
 def is_supported(
     platform: str,
     mesh_shape: tuple[int, int],
-    param_dtype: jnp.dtype,
     inference_dtype: Optional[jnp.dtype],
     global_batch_size: int,
     data_partition: DataPartitionType,
-    use_ema: bool = False,
 ):
-    del param_dtype, use_ema  # not used
     # TODO(xuan-zou): jax 0.4.25 breaks bfloat16 on CPU due to high variance on
     # the final result (up to 10% precision diff), will re-enable when fixed.
     # NOTE: bfloat16 test on GPU is added and verified.