Implement jit sharding.
PiperOrigin-RevId: 617009006
Xingyi Zhou authored and Scenic Authors committed Mar 19, 2024
1 parent 592970e commit 379f472
Showing 2 changed files with 131 additions and 11 deletions.
73 changes: 62 additions & 11 deletions scenic/dataset_lib/flexio/flexio.py
@@ -24,6 +24,7 @@
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
from scenic.dataset_lib import dataset_utils
from scenic.dataset_lib import datasets
import tensorflow as tf
@@ -37,6 +38,7 @@
Features = preprocess_spec.Features
TfFeature = Union[tf.io.FixedLenFeature, tf.io.VarLenFeature,
tf.io.FixedLenSequenceFeature]
PyTree = Any

# From grain/_src/core/constants.py
GRAIN_META_DATA = [
@@ -77,6 +79,35 @@ def tf2jax_dtype(dtype: tf.dtypes.DType) -> Union[jnp.dtype, tf.dtypes.DType]:



def shard_jit(data: PyTree, global_devices: np.ndarray) -> PyTree:
"""Shards data for use in jit-based pipelines.
Note that the order of global devices for sharding data is important and
should be compatible with device order used in the rest of the trainer for
models params, state, etc.
Args:
data: PyTree of data. Assumed to already contain Numpy arrays.
global_devices: List of global devices to shard over.
Returns:
Sharded data.
"""

def _shard_array(x):
mesh = jax.sharding.Mesh(global_devices, ('devices',))
sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec('devices'))
local_ds = mesh.local_devices

xs = jax.device_put(np.split(x, len(local_ds), axis=0), local_ds)

global_shape = (x.shape[0] * jax.process_count(), *x.shape[1:])
return jax.make_array_from_single_device_arrays(global_shape, sharding, xs)

return jax.tree_util.tree_map(_shard_array, data)


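For reference, a minimal usage sketch of shard_jit (illustrative only and not part of this commit; the batch contents and sizes below are assumptions, and the per-host batch size must be divisible by jax.local_device_count()):

import jax
import numpy as np
from jax.experimental import mesh_utils
from scenic.dataset_lib.flexio import flexio

# 1-D mesh over all global devices; the device order should match the mesh
# used by the trainer, as noted in the docstring above.
global_devices = mesh_utils.create_device_mesh((jax.device_count(),))

# Per-host batch as produced by the tf.data pipeline (already NumPy arrays).
per_host_batch = {'image': np.zeros((8, 64, 64, 3), np.float32)}

# Each leaf becomes a global jax.Array whose leading dimension is
# 8 * jax.process_count(), sharded along the 'devices' axis.
global_batch = flexio.shard_jit(per_host_batch, global_devices)
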
def apply_process_fn_with_populated_seed(ds: tf.data.Dataset,
preprocess_fn: Callable[[Features],
Features], *,
@@ -294,7 +325,7 @@ def _build_pipeline(
num_local_shards: int,
rng: Union[None, jnp.ndarray, tf.Tensor] = None,
global_rng: Union[None, jnp.ndarray, tf.Tensor] = None,
shuffle: bool = False
shuffle: bool = False,
) -> Optional[Union[tf.data.Dataset, Dict[str, tf.data.Dataset]]]:
"""Build a tf.data.Dataset pipeline using clu.deterministic_data or DMVR.
@@ -304,6 +335,8 @@
dataset_configs: Dataset configurations.
batch_size: Total batch size (sum for all devices).
num_local_shards: Number of local shards (usually num local devices).
      A value <= 0 means batches are not sharded across devices and a single
      batch dimension is used instead of two.
rng: Per-host random seed (JAX format).
global_rng: Global random seed (JAX format).
shuffle: Whether to shuffle.
Expand Down Expand Up @@ -421,11 +454,14 @@ def _batch_and_prefetch(ds, batch_size):
return ds

# Batch to the desired output batch size:
if batch_size % num_local_shards != 0:
if num_local_shards > 0 and batch_size % num_local_shards != 0:
    raise ValueError(
        f'Local (host) batch size of {batch_size} is not divisible '
        f'by num_local_shards={num_local_shards}.')
batch_dims = [num_local_shards, batch_size // num_local_shards]
if num_local_shards > 0:
batch_dims = [num_local_shards, batch_size // num_local_shards]
else:
batch_dims = [batch_size]
for batch_size in reversed(batch_dims):
if dataset_configs.get('padded_batch'):
ds = ds.padded_batch(batch_size, drop_remainder=True)
@@ -488,7 +524,8 @@ def get_iterator(
ds: Union[tf.data.Dataset, Dict[str, tf.data.Dataset]],
configs=ml_collections.ConfigDict,
*,
return_iterator: bool = False
return_iterator: bool = False,
devices_jit: Optional[np.ndarray] = None,
) -> Tuple[Union[Iterable[Any] | None, Dict[str, Iterable[Any] | None]], Union[
Tuple[Any, ...], Dict[str, Tuple[Any, ...]]], Union[int, Dict[str, int]]]:
"""Given a (dict of) Dataset object(s), returns iterators and metadata.
@@ -498,6 +535,7 @@ def get_iterator(
configs: A Config dict.
    return_iterator: If False, the function returns None instead of an
      iterator.
devices_jit: List of devices to shard the data over for jit-based pipelines.
Returns:
Iterators, input specification and num_examples.
@@ -522,6 +560,10 @@ def _get_input_spec(ds):
else:
ds_it = iter(dataset)
ds_iter[dataset_name] = map(dataset_utils.tf_to_numpy, ds_it)
if devices_jit is not None:
ds_iter[dataset_name] = map(
functools.partial(shard_jit, global_devices=devices_jit),
ds_iter[dataset_name])
input_spec[dataset_name] = _get_input_spec(dataset)
# TODO(dehghani): Add support for having different input specs.
first_input_spec = list(input_spec.values())[0]
@@ -536,6 +578,10 @@ def _get_input_spec(ds):
else:
ds_it = iter(ds)
ds_iter = map(dataset_utils.tf_to_numpy, ds_it)
if devices_jit is not None:
ds_iter = map(
functools.partial(
shard_jit, global_devices=devices_jit), ds_iter)
total_examples = sum(list(total_examples.values()))
input_spec = _get_input_spec(ds)
else:
@@ -557,7 +603,8 @@ def get_dataset(
start_step: Optional[int] = None,
dtype_str: str = 'float32',
shuffle_seed: int = 0,
dataset_service_address: Optional[str] = None) -> dataset_utils.Dataset:
dataset_service_address: Optional[str] = None,
devices: Optional[np.ndarray] = None) -> dataset_utils.Dataset:
"""Returns generators for video datasets.
Args:
@@ -571,6 +618,7 @@
dtype_str: Data type of the image. Only 'float32' is currently supported.
shuffle_seed: Unsupported; use rng instead.
dataset_service_address: Unsupported; must be None.
devices: List of devices to shard the data over for jit-based pipelines.
Returns:
A dataset_utils.Dataset() which includes a train_iter, a valid_iter,
@@ -601,7 +649,7 @@ def get_dataset(
start_step=start_step,
dataset_configs=dataset_configs,
batch_size=batch_size,
num_local_shards=num_shards,
num_local_shards=num_shards if devices is None else -1,
rng=train_rng,
global_rng=global_rng,
shuffle=True)
@@ -613,19 +661,21 @@
start_step=0,
dataset_configs=dataset_configs,
batch_size=eval_batch_size,
num_local_shards=num_shards,
num_local_shards=num_shards if devices is None else -1,
global_rng=global_rng,
rng=eval_rng)

return_iterators = dataset_configs.get('return_iterators', True)
train_iter, train_input_spec, total_train_examples = get_iterator(
train_ds,
dataset_configs.get('train'),
return_iterator=return_iterators)
return_iterator=return_iterators,
devices_jit=devices)
eval_iter, eval_input_spec, total_eval_examples = get_iterator(
eval_ds,
dataset_configs.get('eval'),
return_iterator=return_iterators)
return_iterator=return_iterators,
devices_jit=devices)

# Testing dataset:
rng, test_rng = jax.random.split(rng)
@@ -634,14 +684,15 @@
start_step=0,
dataset_configs=dataset_configs,
batch_size=eval_batch_size,
num_local_shards=num_shards,
num_local_shards=num_shards if devices is None else -1,
global_rng=global_rng,
rng=test_rng)

test_iter, test_input_spec, total_test_examples = get_iterator(
test_ds,
dataset_configs.get('test'),
return_iterator=return_iterators)
return_iterator=return_iterators,
devices_jit=devices)

# Collect dataset metadata.
meta_data = {
69 changes: 69 additions & 0 deletions scenic/dataset_lib/flexio/tests/flexio_test.py
@@ -19,6 +19,8 @@
from grand_vision.preprocessing import image_ops
from grand_vision.preprocessing import modalities
import jax
from jax._src import array as jax_array
from jax.experimental import mesh_utils
import ml_collections
from scenic.dataset_lib.flexio import flexio
import tensorflow as tf
@@ -92,6 +94,73 @@ def test_tfds_datasets(self, train_tfds_name, eval_tfds_name):
self.assertDictEqual(
jax.tree_util.tree_map(lambda x: x.shape, valid_data), expected_shapes)

@parameterized.named_parameters(
('coco_coco', 'coco', 'coco'),
)
def test_sharded_tfds_datasets(self, train_tfds_name, eval_tfds_name):
"""Test TFDS dataset loading."""
dataset_configs = D({
'train': {
'sources': [D({
'source': 'tfds',
'tfds_name': train_tfds_name,
'split': 'train',
'shuffle_buffer_size': 2,
'cache': False,
'preproc_spec': 'decode_coco_example|crop_or_pad(64, 16)',
})],
'preproc_spec': 'crop_or_pad_meta_data(16, 16)',
},
'eval': {
'sources': [D({
'source': 'tfds',
'tfds_name': eval_tfds_name,
'split': 'validation',
'shuffle_buffer_size': 1,
'cache': False,
'preproc_spec': 'decode_coco_example',
})],
'preproc_spec': ('central_crop(64)'
'|crop_or_pad(64, 16)'
'|crop_or_pad_meta_data(16, 16)'),
},
'pp_libs': [ # We override the default ops.
'grand_vision.preprocessing.image_ops']
})
rng = jax.random.PRNGKey(0)
devices = mesh_utils.create_device_mesh((jax.device_count(),))
ds = flexio.get_dataset(
batch_size=8,
eval_batch_size=8,
num_shards=jax.local_device_count(),
rng=rng,
dataset_configs=dataset_configs,
devices=devices)
prefix_shape = (8,)
expected_shapes = {
modalities.ANNOTATION_ID: prefix_shape + (16,),
modalities.AREA: prefix_shape + (16,),
modalities.BOXES: prefix_shape + (16, 4),
modalities.CROWD: prefix_shape + (16,),
modalities.IMAGE: prefix_shape + (64, 64, 3),
modalities.IMAGE_ID: prefix_shape,
modalities.IMAGE_PADDING_MASK: prefix_shape + (64, 64),
modalities.INSTANCE_LABELS: prefix_shape + (16,),
modalities.ORIGINAL_SIZE: prefix_shape + (2,),
image_ops.SEED_KEY: prefix_shape + (2,)
}
train_data = next(ds.train_iter)
valid_data = next(ds.valid_iter)
self.assertDictEqual(
jax.tree_util.tree_map(lambda x: x.shape, train_data), expected_shapes)
self.assertDictEqual(
jax.tree_util.tree_map(lambda x: x.shape, valid_data), expected_shapes)
self.assertDictEqual(
jax.tree_util.tree_map(type, train_data),
jax.tree_util.tree_map(lambda x: jax_array.ArrayImpl, train_data))
self.assertDictEqual(
jax.tree_util.tree_map(type, valid_data),
jax.tree_util.tree_map(lambda x: jax_array.ArrayImpl, valid_data))

if __name__ == '__main__':
absltest.main()

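For context, a sketch of how a jit-based trainer might consume these globally sharded batches (illustrative only; the train_step below and the reuse of the test's dataset_configs are assumptions, not part of this commit):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from scenic.dataset_lib.flexio import flexio

# Use the same 1-D device mesh for the data pipeline and the trainer so the
# device order used for sharding batches matches the rest of the model.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
ds = flexio.get_dataset(
    batch_size=8,
    eval_batch_size=8,
    num_shards=jax.local_device_count(),
    rng=jax.random.PRNGKey(0),
    dataset_configs=dataset_configs,  # e.g. the config dict from the test above
    devices=devices)

@jax.jit
def train_step(batch):
  # Leaves arrive as jax.Arrays already sharded over their leading (batch)
  # dimension, so jit picks up the input sharding without explicit in_shardings.
  return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32).mean(), batch)

metrics = train_step(next(ds.train_iter))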