diff --git a/composer/core/data_spec.py b/composer/core/data_spec.py index 670011e05b..a1613cd392 100644 --- a/composer/core/data_spec.py +++ b/composer/core/data_spec.py @@ -14,7 +14,7 @@ import torch.utils.data from torch.utils.data.distributed import DistributedSampler -from composer.utils import dist, ensure_tuple +from composer.utils import VersionedDeprecationWarning, dist, ensure_tuple if TYPE_CHECKING: from composer.core.types import Batch @@ -126,16 +126,16 @@ def _default_split_batch(batch: Any, microbatch_size: Union[int, float]) -> Sequ class DataSpec: """Specifications for operating and training on data. - An example of constructing a :class:`DataSpec` object with a ``device_transforms`` + An example of constructing a :class:`DataSpec` object with a ``batch_transforms`` callable and then using it with :class:`~.Trainer`: .. doctest:: >>> # Construct DataSpec and subtract mean from the batch - >>> device_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys) - >>> train_dspec = DataSpec(train_dataloader, device_transforms=device_transform_fn) + >>> batch_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys) + >>> train_dspec = DataSpec(train_dataloader, batch_transforms=batch_transform_fn) >>> # The same function can be used for eval dataloader as well - >>> eval_dspec = DataSpec(eval_dataloader, device_transforms=device_transform_fn) + >>> eval_dspec = DataSpec(eval_dataloader, batch_transforms=batch_transform_fn) >>> # Use this DataSpec object to construct trainer >>> trainer = Trainer( ... model=model, @@ -155,11 +155,20 @@ class DataSpec: num_tokens (int, optional): The total number of tokens in an epoch. This field is used by the :class:`.Timestamp` (training progress tracker). - device_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the - batch once it has been moved onto the device. For example, this function can be used for GPU-based + device_transforms ((Batch) -> Batch, optional): Deprecated argument. Please use ``batch_transforms`` for batch + level transformations on CPU and ``microbatch_transforms`` for microbatch level transformations on target + device. + + batch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the + batch before it is moved onto the device. For example, this function can be used for CPU-based normalization. It can modify the batch in-place, and it should return the modified batch. If not specified, the batch is not modified. + microbatch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the + microbatch before it is moved onto the device. For example, this function can be used for GPU-based + normalization. It can modify the microbatch in-place, and it should return the modified microbatch. If not + specified, the microbatch is not modified. + split_batch ((Batch, (int | float)) -> Sequence[Batch], optional): Function called by the :class:`.Trainer` to split a batch (the first parameter) into microbatches of a given size (the second parameter). If the ``dataloader`` yields batches not of type :class:`torch.Tensor`, Mapping, tuple, or list, then @@ -186,13 +195,32 @@ def __init__( num_samples: Optional[int] = None, num_tokens: Optional[int] = None, device_transforms: Optional[Callable[[Batch], Batch]] = None, + batch_transforms: Optional[Callable[[Batch], Batch]] = None, + microbatch_transforms: Optional[Callable[[Batch], Batch]] = None, split_batch: Optional[Callable[[Batch, Union[int, float]], Sequence[Batch]]] = None, get_num_samples_in_batch: Optional[Callable[[Batch], Union[int, float]]] = None, get_num_tokens_in_batch: Optional[Callable[[Batch], Union[int, dict[str, int]]]] = None, ) -> None: self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader self.num_tokens = num_tokens - self.device_transforms = self._default_device_transforms if device_transforms is None else device_transforms + if device_transforms is not None: + if batch_transforms is not None: + raise ValueError( + 'Cannot specify both `device_transforms` and `batch_transforms`. Please use `batch_transforms` for ' + 'batch level transformations on CPU and `microbatch_transforms` for microbatch level transformations ' + 'on target device.', + ) + warnings.warn( + VersionedDeprecationWarning( + 'The `device_transforms` argument is deprecated. Please use `batch_transforms` for batch level ' + 'transformations on CPU and `microbatch_transforms` for microbatch level transformations on target ' + 'device.', + 'v0.29.0', + ), + ) + self.batch_transforms = device_transforms + self.batch_transforms = self._default_transforms if batch_transforms is None else batch_transforms + self.microbatch_transforms = self._default_transforms if microbatch_transforms is None else microbatch_transforms self.split_batch = default_split_batch if split_batch is None else split_batch self.get_num_samples_in_batch = self._default_get_num_samples_in_batch if get_num_samples_in_batch is None else get_num_samples_in_batch self._get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch @@ -242,7 +270,7 @@ def __init__( 'For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler.', ) - def _default_device_transforms(self, batch: Batch): + def _default_transforms(self, batch: Batch): return batch def _default_get_num_samples_in_batch(self, batch: Batch) -> int: diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 43d07d2542..c39a0b7b83 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2622,7 +2622,7 @@ def _train_loop(self) -> None: self._rng_state = None continue - self.state.batch = self._train_data_spec.device_transforms(self.state.batch) + self.state.batch = self._train_data_spec.batch_transforms(self.state.batch) rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch) rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch) @@ -3034,6 +3034,7 @@ def _train_microbatches( for microbatch_idx, self.state.batch in enumerate(microbatches): self.state.batch = self.state.device.batch_to_device(self.state.batch) + self.state.batch = self._train_data_spec.microbatch_transforms(self.state.batch) is_final_microbatch = microbatch_idx + 1 == len(microbatches) microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch) @@ -3306,11 +3307,11 @@ def predict_batch_end(self, state: State, logger: Logger) -> None: self.engine.run_event(Event.PREDICT_START) for self.state.batch in self._iter_dataloader(TrainerMode.PREDICT): + # Move the batch onto the device + self.state.batch = data_spec.batch_transforms(self.state.batch) self.state.batch = self.state.device.batch_to_device(self.state.batch) - - # Perform any device transforms - self.state.batch = data_spec.device_transforms(self.state.batch) + self.state.batch = data_spec.microbatch_transforms(self.state.batch) # Count the batch size and num tokens before any events run rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch) @@ -3586,7 +3587,7 @@ def _eval_loop( ) for self.state.batch in self._iter_dataloader(TrainerMode.EVAL): - self.state.batch = data_spec.device_transforms(self.state.batch) + self.state.batch = data_spec.batch_transforms(self.state.batch) # Count the batch size and num tokens before any events run rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch) @@ -3616,6 +3617,7 @@ def _eval_loop( microbatches = data_spec.split_batch(device_batch, evaluator.device_eval_microbatch_size) for i, self.state.batch in enumerate(microbatches): self.state.batch = self.state.device.batch_to_device(self.state.batch) + self.state.batch = data_spec.microbatch_transforms(self.state.batch) last_microbatch = i == len(microbatches) - 1 skip_metric_update = False # Distributed samplers pad batches to be the same size. If using a diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 333fb2b719..0c62c5c4cc 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -19,7 +19,7 @@ from composer import Callback, Evaluator, Trainer from composer.algorithms import CutOut, LabelSmoothing -from composer.core import Event, Precision, State, Time, TimeUnit +from composer.core import DataSpec, Event, Precision, State, Time, TimeUnit from composer.devices import Device from composer.loggers import InMemoryLogger, Logger, RemoteUploaderDownloader from composer.loss import soft_cross_entropy @@ -1733,3 +1733,26 @@ def test_empty_eval_dataloader(self): max_duration='1ba', ) trainer.fit() + + +@device('cpu', 'gpu') +def test_transforms(device: str): + + def get_transform(device: str): + + def transform(batch: list[torch.Tensor]): + batch_device = 'gpu' if batch[0].device.type == 'cuda' else 'cpu' + assert batch_device == device + return batch + + return transform + + dataloader = _get_classification_dataloader() + data_spec = DataSpec( + dataloader, + batch_transforms=get_transform('cpu'), + microbatch_transforms=get_transform(device), + ) + model = SimpleModel() + trainer = Trainer(model=model, train_dataloader=data_spec, max_duration='1ba') + trainer.fit()