diff --git a/axlearn/common/flash_attention/neuron_attention_test.py b/axlearn/common/flash_attention/neuron_attention_test.py index 35737b1a9..370df3660 100644 --- a/axlearn/common/flash_attention/neuron_attention_test.py +++ b/axlearn/common/flash_attention/neuron_attention_test.py @@ -12,7 +12,7 @@ if jax.default_backend() != "neuron": - pytestmark = pytest.mark.skip(reason="Incompatible hardware, Neuron only test.") + pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.") @pytest.mark.parametrize(