fix host array test
apoorvtintin committed Dec 11, 2024
1 parent 482b7fa commit e9e3279
Showing 1 changed file with 16 additions and 19 deletions.
axlearn/common/host_array_test.py: 16 additions & 19 deletions
@@ -22,31 +22,31 @@
    host_to_global_device_array,
)


def is_supported(
    platform: str,
    mesh_shape: tuple[int, int],
    global_batch_size: int,
    data_partition: DataPartitionType,
):
-    if not is_supported_platform(platform):
-        return False, f'Platform "{platform}" not supported with devices {jax.devices()}.'
-    if not is_supported_mesh_shape(mesh_shape):
-        return False, f'Mesh shape "{mesh_shape}" not supported with device_count "{jax.device_count()}".'
-    if data_partition != DataPartitionType.REPLICATED:
-        return False, f'Data partition is "{data_partition}", expected "DataPartitionType.REPLICATED".'
-    if global_batch_size % jax.device_count() != 0:
-        return False, f'Global batch size has to be divisible by the number of devices. Global batch size is "{global_batch_size}", number of devices is "{jax.device_count()}".'
-    return True, ""

+    return (
+        is_supported_platform(platform)
+        and is_supported_mesh_shape(mesh_shape)
+        and (
+            data_partition == DataPartitionType.REPLICATED
+            or global_batch_size % jax.device_count() == 0
+        )
+    )

class HostArrayTest(TestCase):
    @parameterized.parameters(
-        itertools.product(
-            ("cpu", "tpu"),  # platform,
-            ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)),  # mesh_shape
-            (1, 16),  # global_batch_size
-            (DataPartitionType.FULL, DataPartitionType.REPLICATED),  # data_partition
+        filter(
+            lambda params: is_supported(*params),
+            itertools.product(
+                ("cpu", "tpu"),  # platform,
+                ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2), (16, 4)),  # mesh_shape
+                (1, 16),  # global_batch_size
+                (DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH),  # data_partition
+            ),
        )
    )
    def test_global_host_array_conversion(
@@ -63,9 +63,6 @@ def test_global_host_array_conversion(
            global_batch_size,
            data_partition,
        )
-        supported, reason = is_supported(platform, mesh_shape, global_batch_size, data_partition)
-        if not supported:
-            pytest.skip(reason)
        devices = mesh_utils.create_device_mesh(mesh_shape)
        mesh = jax.sharding.Mesh(devices, ("data", "model"))
        logging.info("Global mesh: %s", mesh)
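The pattern at the heart of this change is worth spelling out: instead of generating every parameter combination and skipping unsupported ones at runtime with pytest.skip, the sweep is pruned with filter() at collection time, before parameterized.parameters ever expands it. Below is a minimal, self-contained sketch of that pattern; the predicate and test names are illustrative stand-ins, not code from axlearn.

import itertools

from absl.testing import parameterized


def is_supported(platform: str, num_devices: int) -> bool:
    # Toy stand-in: the real predicate also checks the mesh shape,
    # batch divisibility, and the data-partition mode.
    return platform == "cpu" or num_devices >= 4


class FilteredSweepTest(parameterized.TestCase):
    # filter() drops unsupported (platform, num_devices) pairs before
    # the decorator expands them, so they never become test cases.
    @parameterized.parameters(
        filter(
            lambda params: is_supported(*params),
            itertools.product(("cpu", "tpu"), (1, 4, 8)),
        )
    )
    def test_combo(self, platform: str, num_devices: int):
        self.assertTrue(is_supported(platform, num_devices))

The trade-off versus the removed pytest.skip approach: collection-time filtering keeps unsupported combinations out of the test report entirely, whereas pytest.skip would list each one as skipped along with its reason string.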
