diff --git a/axlearn/common/host_array_test.py b/axlearn/common/host_array_test.py
index 378e734d9..73b27ddc9 100644
--- a/axlearn/common/host_array_test.py
+++ b/axlearn/common/host_array_test.py
@@ -22,31 +22,31 @@
     host_to_global_device_array,
 )
 
+
 def is_supported(
     platform: str,
     mesh_shape: tuple[int, int],
     global_batch_size: int,
     data_partition: DataPartitionType,
 ):
-    return (
-        is_supported_platform(platform)
-        and is_supported_mesh_shape(mesh_shape)
-        and (
-            data_partition == DataPartitionType.REPLICATED
-            or global_batch_size % jax.device_count() == 0
-        )
-    )
+    if not is_supported_platform(platform):
+        return False, f'Platform "{platform}" not supported with devices {jax.devices()}.'
+    if not is_supported_mesh_shape(mesh_shape):
+        return False, f'Mesh shape "{mesh_shape}" not supported with device_count "{jax.device_count()}".'
+    if data_partition != DataPartitionType.REPLICATED and global_batch_size % jax.device_count() != 0:
+        return False, (
+            f'Global batch size must be divisible by the number of devices unless data_partition '
+            f'is "{DataPartitionType.REPLICATED}". Global batch size is "{global_batch_size}", '
+            f'number of devices is "{jax.device_count()}".'
+        )
+    return True, ""
 
+
 class HostArrayTest(TestCase):
     @parameterized.parameters(
-        filter(
-            lambda params: is_supported(*params),
-            itertools.product(
-                ("cpu", "tpu"),  # platform,
-                ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2), (16, 4)),  # mesh_shape
-                (1, 16),  # global_batch_size
-                (DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH),  # data_partition
-            ),
-        )
+        itertools.product(
+            ("cpu", "tpu"),  # platform
+            ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)),  # mesh_shape
+            (1, 16),  # global_batch_size
+            (DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH),  # data_partition
+        )
     )
     def test_global_host_array_conversion(
@@ -63,6 +63,9 @@ def test_global_host_array_conversion(
         global_batch_size,
         data_partition,
     ):
+        supported, reason = is_supported(platform, mesh_shape, global_batch_size, data_partition)
+        if not supported:
+            pytest.skip(reason)
         devices = mesh_utils.create_device_mesh(mesh_shape)
         mesh = jax.sharding.Mesh(devices, ("data", "model"))
         logging.info("Global mesh: %s", mesh)
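
Note: the diff replaces silent parameter filtering with explicit test skips that carry a reason.
Below is a minimal, self-contained sketch of the same (supported, reason) pattern, independent of
the axlearn code above; the names device_count and test_example are illustrative, not from the diff.

import pytest

def is_supported(global_batch_size: int, device_count: int):
    # Mirror the tuple-returning contract from the diff: (supported, reason).
    if global_batch_size % device_count != 0:
        return False, (
            f'Global batch size "{global_batch_size}" is not divisible by '
            f'device count "{device_count}".'
        )
    return True, ""

@pytest.mark.parametrize("global_batch_size", [1, 16])
def test_example(global_batch_size):
    supported, reason = is_supported(global_batch_size, device_count=8)
    if not supported:
        pytest.skip(reason)  # Reported as an explicit skip with the reason.
    assert global_batch_size % 8 == 0

Unlike the removed filter(...) wrapper, unsupported combinations now show up in the test report
as skips with a human-readable reason instead of vanishing from the parameter set.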