fix regressions and messages #29

Merged
axlearn/cloud/gcp/tpu_health_check.py (1 change: 1 addition & 0 deletions)

@@ -31,6 +31,7 @@
from contextlib import contextmanager
from datetime import datetime
from typing import Literal, Optional, Union
+import pytest

import tensorflow as tf
from absl import flags, logging
axlearn/common/flash_attention/neuron_attention_test.py (6 changes: 4 additions & 2 deletions)

@@ -11,6 +11,10 @@
from axlearn.common.flash_attention.utils import mha_reference


+if jax.default_backend() != "neuron":
+    pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.")
+
+
@pytest.mark.parametrize(
"batch_size,seq_len,num_heads,per_head_dim",
[
@@ -25,7 +29,6 @@
@pytest.mark.parametrize("use_fwd", [True, False])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
-@pytest.mark.skipif(jax.devices()[0].platform != "neuron", reason="Test only runs on Neuron.")
def test_fwd_against_ref(
    batch_size: int,
    seq_len: int,
@@ -81,7 +84,6 @@ def impl(q, k, v, bias):
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
-@pytest.mark.skipif(jax.devices()[0].platform != "neuron", reason="Test only runs on Neuron.")
def test_bwd_against_ref(
    batch_size: int,
    num_heads: int,
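Note: these hunks replace the per-test @pytest.mark.skipif decorators with a single module-level pytestmark guard, so every test collected from the file is skipped unless the default JAX backend is Neuron, and new tests added later inherit the guard automatically. A minimal, self-contained sketch of that pytest mechanism follows; the test function and its body are illustrative only, not part of this PR.

import jax
import pytest

# Assigning a skip marker to the module attribute `pytestmark` applies it to
# every test collected from this file, replacing per-test skipif decorators.
# jax.default_backend() returns the default backend's platform name
# ("cpu", "gpu", "tpu", or "neuron" when the AWS Neuron plugin is installed).
if jax.default_backend() != "neuron":
    pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.")


def test_requires_neuron():
    # Runs only on a Neuron backend; elsewhere pytest reports the whole
    # module as skipped.
    assert jax.default_backend() == "neuron"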
axlearn/common/host_array_test.py (1 change: 1 addition & 0 deletions)

@@ -12,6 +12,7 @@
from absl.testing import absltest, parameterized
from jax import numpy as jnp
from jax.experimental import mesh_utils
+import pytest

from axlearn.common.test_utils import TestCase, is_supported_mesh_shape, is_supported_platform
from axlearn.common.utils import (
axlearn/common/inference_test.py (4 changes: 1 addition & 3 deletions)

@@ -20,6 +20,7 @@
from absl.testing import absltest, parameterized
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit
+import pytest

from axlearn.common import layers, test_utils, utils
from axlearn.common.base_model import BaseModel
@@ -168,13 +169,10 @@ def predict_batch(self, input_batch: NestedTensor) -> NestedTensor:
def is_supported(
    platform: str,
    mesh_shape: tuple[int, int],
-    param_dtype: jnp.dtype,
    inference_dtype: Optional[jnp.dtype],
    global_batch_size: int,
    data_partition: DataPartitionType,
-    use_ema: bool = False,
):
-    del param_dtype, use_ema  # not used
    # TODO(xuan-zou): jax 0.4.25 breaks bfloat16 on CPU due to high variance on
    # the final result (up to 10% precision diff), will re-enable when fixed.
    # NOTE: bfloat16 test on GPU is added and verified.
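For reference, a minimal sketch of the is_supported helper after this hunk, assuming the three deletions are the unused param_dtype and use_ema parameters and the del statement that silenced them. The body and the placeholder DataPartitionType alias are illustrative assumptions, not the file's actual implementation.

from typing import Any, Optional

from jax import numpy as jnp

# Placeholder so this sketch is self-contained; the real test imports
# DataPartitionType from the axlearn codebase.
DataPartitionType = Any


def is_supported(
    platform: str,
    mesh_shape: tuple[int, int],
    inference_dtype: Optional[jnp.dtype],
    global_batch_size: int,
    data_partition: DataPartitionType,
):
    # Illustrative predicate only: it mirrors the bfloat16-on-CPU caveat from
    # the TODO(xuan-zou) comment above; the real helper also checks the mesh
    # shape, batch size, and data partitioning.
    return not (platform == "cpu" and inference_dtype == jnp.bfloat16)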