From c566bd5dfe9736136a809a86caf5d7b650ab556c Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 5 Nov 2024 16:19:47 -0500
Subject: [PATCH 1/8] fix

---
 composer/core/data_spec.py  | 38 ++++++++++++++++++++++++++++++++-----
 composer/trainer/trainer.py |  9 ++++++---
 2 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/composer/core/data_spec.py b/composer/core/data_spec.py
index 670011e05b..c373999488 100644
--- a/composer/core/data_spec.py
+++ b/composer/core/data_spec.py
@@ -14,7 +14,7 @@
 import torch.utils.data
 from torch.utils.data.distributed import DistributedSampler
 
-from composer.utils import dist, ensure_tuple
+from composer.utils import VersionedDeprecationWarning, dist, ensure_tuple
 
 if TYPE_CHECKING:
     from composer.core.types import Batch
@@ -155,11 +155,20 @@ class DataSpec:
         num_tokens (int, optional): The total number of tokens in an epoch. This field is used by the
             :class:`.Timestamp` (training progress tracker).
 
-        device_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
-            batch once it has been moved onto the device. For example, this function can be used for GPU-based
+        device_transforms ((Batch) -> Batch, optional): Deprecated argument. Please use ``batch_transforms`` for batch
+            level transformations on CPU and ``microbatch_transforms`` for microbatch level transformations on target
+            device.
+
+        batch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
+            batch before it is moved onto the device. For example, this function can be used for CPU-based
             normalization. It can modify the batch in-place, and it should return the modified batch. If not specified,
             the batch is not modified.
 
+        microbatch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
+            microbatch before it is moved onto the device. For example, this function can be used for GPU-based
+            normalization. It can modify the microbatch in-place, and it should return the modified microbatch. If not
+            specified, the microbatch is not modified.
+
         split_batch ((Batch, (int | float)) -> Sequence[Batch], optional): Function called by the :class:`.Trainer` to
             split a batch (the first parameter) into microbatches of a given size (the second parameter). If the
             ``dataloader`` yields batches not of type :class:`torch.Tensor`, Mapping, tuple, or list, then
@@ -186,13 +195,32 @@ def __init__(
         num_samples: Optional[int] = None,
         num_tokens: Optional[int] = None,
         device_transforms: Optional[Callable[[Batch], Batch]] = None,
+        batch_transforms: Optional[Callable[[Batch], Batch]] = None,
+        microbatch_transforms: Optional[Callable[[Batch], Batch]] = None,
         split_batch: Optional[Callable[[Batch, Union[int, float]], Sequence[Batch]]] = None,
         get_num_samples_in_batch: Optional[Callable[[Batch], Union[int, float]]] = None,
         get_num_tokens_in_batch: Optional[Callable[[Batch], Union[int, dict[str, int]]]] = None,
     ) -> None:
         self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader
         self.num_tokens = num_tokens
-        self.device_transforms = self._default_device_transforms if device_transforms is None else device_transforms
+        if device_transforms is not None:
+            if batch_transforms is not None:
+                raise ValueError(
+                    'Cannot specify both `device_transforms` and `batch_transforms`. Please use `batch_transforms` for '
+                    'batch level transformations on CPU and `microbatch_transforms` for microbatch level transformations '
+                    'on target device.',
+                )
+            warnings.warn(
+                VersionedDeprecationWarning(
+                    'The `device_transforms` argument is deprecated. Please use `batch_transforms` for batch level '
+                    'transformations on CPU and `microbatch_transforms` for microbatch level transformations on target '
+                    'device.',
+                    'v0.28.0',
+                )
+            )
+            self.batch_transforms = device_transforms
+        self.batch_transforms = self._default_transforms if batch_transforms is None else batch_transforms
+        self.microbatch_transforms = self._default_transforms if microbatch_transforms is None else microbatch_transforms
         self.split_batch = default_split_batch if split_batch is None else split_batch
         self.get_num_samples_in_batch = self._default_get_num_samples_in_batch if get_num_samples_in_batch is None else get_num_samples_in_batch
         self._get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch
@@ -242,7 +270,7 @@ def __init__(
                     'For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler.',
                 )
 
-    def _default_device_transforms(self, batch: Batch):
+    def _default_transforms(self, batch: Batch):
         return batch
 
     def _default_get_num_samples_in_batch(self, batch: Batch) -> int:
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 43d07d2542..98d58ec2ea 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -2622,7 +2622,7 @@ def _train_loop(self) -> None:
                     self._rng_state = None
                     continue
 
-                self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
+                self.state.batch = self._train_data_spec.batch_transforms(self.state.batch)
                 rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
                 rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
 
@@ -3034,6 +3034,7 @@ def _train_microbatches(
 
             for microbatch_idx, self.state.batch in enumerate(microbatches):
                 self.state.batch = self.state.device.batch_to_device(self.state.batch)
+                self.state.batch = self._train_data_spec.microbatch_transforms(self.state.batch)
                 is_final_microbatch = microbatch_idx + 1 == len(microbatches)
                 microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch)
 
@@ -3310,7 +3311,8 @@ def predict_batch_end(self, state: State, logger: Logger) -> None:
             self.state.batch = self.state.device.batch_to_device(self.state.batch)
 
             # Perform any device transforms
-            self.state.batch = data_spec.device_transforms(self.state.batch)
+            self.state.batch = data_spec.batch_transforms(self.state.batch)
+            self.state.batch = data_spec.microbatch_transforms(self.state.batch)
 
             # Count the batch size and num tokens before any events run
             rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch)
@@ -3586,7 +3588,7 @@ def _eval_loop(
                 )
 
                 for self.state.batch in self._iter_dataloader(TrainerMode.EVAL):
-                    self.state.batch = data_spec.device_transforms(self.state.batch)
+                    self.state.batch = data_spec.batch_transforms(self.state.batch)
 
                     # Count the batch size and num tokens before any events run
                     rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch)
@@ -3616,6 +3618,7 @@ def _eval_loop(
                             microbatches = data_spec.split_batch(device_batch, evaluator.device_eval_microbatch_size)
                             for i, self.state.batch in enumerate(microbatches):
                                 self.state.batch = self.state.device.batch_to_device(self.state.batch)
+                                self.state.batch = data_spec.microbatch_transforms(self.state.batch)
                                 last_microbatch = i == len(microbatches) - 1
                                 skip_metric_update = False
                                 # Distributed samplers pad batches to be the same size. If using a

From fd5b05958fe8b132c9f2a24b667f43f673d85a1b Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 5 Nov 2024 16:34:08 -0500
Subject: [PATCH 2/8] lint

---
 composer/core/data_spec.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/composer/core/data_spec.py b/composer/core/data_spec.py
index c373999488..9c049b78a2 100644
--- a/composer/core/data_spec.py
+++ b/composer/core/data_spec.py
@@ -216,7 +216,7 @@ def __init__(
                     'transformations on CPU and `microbatch_transforms` for microbatch level transformations on target '
                     'device.',
                     'v0.28.0',
-                )
+                ),
             )
             self.batch_transforms = device_transforms
         self.batch_transforms = self._default_transforms if batch_transforms is None else batch_transforms

From 17efd4076f14ba8455745039b60c390eb50358ad Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Wed, 6 Nov 2024 14:53:59 -0500
Subject: [PATCH 3/8] fix docs

---
 composer/core/data_spec.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/composer/core/data_spec.py b/composer/core/data_spec.py
index 9c049b78a2..297aa53e3c 100644
--- a/composer/core/data_spec.py
+++ b/composer/core/data_spec.py
@@ -126,16 +126,16 @@ def _default_split_batch(batch: Any, microbatch_size: Union[int, float]) -> Sequ
 class DataSpec:
     """Specifications for operating and training on data.
 
-    An example of constructing a :class:`DataSpec` object with a ``device_transforms``
+    An example of constructing a :class:`DataSpec` object with a ``batch_transforms``
     callable and then using it with :class:`~.Trainer`:
 
     .. doctest::
 
        >>> # Construct DataSpec and subtract mean from the batch
-       >>> device_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys)
-       >>> train_dspec = DataSpec(train_dataloader, device_transforms=device_transform_fn)
+       >>> batch_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys)
+       >>> train_dspec = DataSpec(train_dataloader, batch_transforms=batch_transform_fn)
        >>> # The same function can be used for eval dataloader as well
-       >>> eval_dspec = DataSpec(eval_dataloader, device_transforms=device_transform_fn)
+       >>> eval_dspec = DataSpec(eval_dataloader, batch_transforms=batch_transform_fn)
        >>> # Use this DataSpec object to construct trainer
        >>> trainer = Trainer(
        ...     model=model,

From 3b748bc5ec950a7a19d4b9d0b9779419ee3bc9c1 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Wed, 6 Nov 2024 23:52:28 -0500
Subject: [PATCH 4/8] fix reviewer comments

---
 composer/core/data_spec.py  | 2 +-
 composer/trainer/trainer.py | 5 ++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/composer/core/data_spec.py b/composer/core/data_spec.py
index 297aa53e3c..a1613cd392 100644
--- a/composer/core/data_spec.py
+++ b/composer/core/data_spec.py
@@ -215,7 +215,7 @@ def __init__(
                     'The `device_transforms` argument is deprecated. Please use `batch_transforms` for batch level '
                     'transformations on CPU and `microbatch_transforms` for microbatch level transformations on target '
                     'device.',
-                    'v0.28.0',
+                    'v0.29.0',
                 ),
             )
             self.batch_transforms = device_transforms
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 98d58ec2ea..202b47cee7 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -3307,11 +3307,10 @@ def predict_batch_end(self, state: State, logger: Logger) -> None:
         self.engine.run_event(Event.PREDICT_START)
 
         for self.state.batch in self._iter_dataloader(TrainerMode.PREDICT):
+
             # Move the batch onto the device
-            self.state.batch = self.state.device.batch_to_device(self.state.batch)
-
-            # Perform any device transforms
             self.state.batch = data_spec.batch_transforms(self.state.batch)
+            self.state.batch = self.state.device.batch_to_device(self.state.batch)
             self.state.batch = data_spec.microbatch_transforms(self.state.batch)
 
             # Count the batch size and num tokens before any events run

From b1106efd5c39c26993001304ca264f35fb823a55 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Thu, 7 Nov 2024 00:05:33 -0500
Subject: [PATCH 5/8] add test

---
 composer/trainer/trainer.py   |  2 +-
 tests/trainer/test_trainer.py | 24 +++++++++++++++++++++++-
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 202b47cee7..c39a0b7b83 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -3307,7 +3307,7 @@ def predict_batch_end(self, state: State, logger: Logger) -> None:
         self.engine.run_event(Event.PREDICT_START)
 
         for self.state.batch in self._iter_dataloader(TrainerMode.PREDICT):
-
+            # Move the batch onto the device
             self.state.batch = data_spec.batch_transforms(self.state.batch)
             self.state.batch = self.state.device.batch_to_device(self.state.batch)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 333fb2b719..f327631f33 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -19,7 +19,7 @@
 from composer import Callback, Evaluator, Trainer
 from composer.algorithms import CutOut, LabelSmoothing
-from composer.core import Event, Precision, State, Time, TimeUnit
+from composer.core import Event, Precision, State, Time, TimeUnit, DataSpec
 from composer.devices import Device
 from composer.loggers import InMemoryLogger, Logger, RemoteUploaderDownloader
 from composer.loss import soft_cross_entropy
@@ -1733,3 +1733,25 @@ def test_empty_eval_dataloader(self):
             max_duration='1ba',
         )
         trainer.fit()
+
+
+@device('cpu', 'gpu')
+def test_transforms(device: str):
+
+    def get_transform(device: str):
+
+        def transform(batch: list[torch.Tensor]):
+            assert batch[0].device.type == device
+            return batch
+
+        return transform
+
+    dataloader = _get_classification_dataloader()
+    data_spec = DataSpec(
+        dataloader,
+        batch_transforms=get_transform('cpu'),
+        microbatch_transforms=get_transform(device),
+    )
+    model = SimpleModel()
+    trainer = Trainer(model=model, train_dataloader=data_spec, max_duration='1ba')
+    trainer.fit()

From 11d10ad4889a527f66c64f9d34f7c45bc9fecfc2 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Thu, 7 Nov 2024 00:05:45 -0500
Subject: [PATCH 6/8] lint

---
 tests/trainer/test_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index f327631f33..59e7d69bb1 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -19,7 +19,7 @@
 from composer import Callback, Evaluator, Trainer
 from composer.algorithms import CutOut, LabelSmoothing
-from composer.core import Event, Precision, State, Time, TimeUnit, DataSpec
+from composer.core import DataSpec, Event, Precision, State, Time, TimeUnit
 from composer.devices import Device
 from composer.loggers import InMemoryLogger, Logger, RemoteUploaderDownloader
 from composer.loss import soft_cross_entropy

From 8e46cbcc0b37b87ca84f9f07b581035f7cd2b63f Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Fri, 8 Nov 2024 23:57:51 -0500
Subject: [PATCH 7/8] fox

---
 tests/trainer/test_trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 59e7d69bb1..c9cee20dc4 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1741,6 +1741,7 @@ def test_transforms(device: str):
     def get_transform(device: str):
 
         def transform(batch: list[torch.Tensor]):
+            device = 'cuda' if device == 'gpu' else device
             assert batch[0].device.type == device
             return batch
 

From 6ebad5039c1a880d5949acc57b00b13efb20bbd4 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Sat, 9 Nov 2024 01:53:41 -0500
Subject: [PATCH 8/8] lint

---
 tests/trainer/test_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index c9cee20dc4..0c62c5c4cc 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1741,8 +1741,8 @@ def test_transforms(device: str):
     def get_transform(device: str):
 
         def transform(batch: list[torch.Tensor]):
-            device = 'cuda' if device == 'gpu' else device
-            assert batch[0].device.type == device
+            batch_device = 'gpu' if batch[0].device.type == 'cuda' else 'cpu'
+            assert batch_device == device
             return batch
 
         return transform
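
Reviewer note: below is a minimal usage sketch of the two new DataSpec arguments introduced by this series. It is illustrative only; the `train_dataloader` object and the two transform functions are assumptions made for the example and are not part of the patch. The argument names and their CPU-versus-device semantics come from the docstring updated in PATCH 1/8 and the test added in PATCH 5/8.

    import torch

    from composer.core import DataSpec

    # Runs once per batch, on CPU, before the batch is moved onto the device.
    def normalize_on_cpu(batch):
        xs, ys = batch
        return xs.sub_(xs.mean()), ys

    # Runs once per microbatch, after the microbatch is on the target device.
    def scale_on_device(batch):
        xs, ys = batch
        return xs.div_(255.0), ys

    # `train_dataloader` is assumed to be an existing PyTorch dataloader that
    # yields (inputs, targets) tuples.
    train_dspec = DataSpec(
        train_dataloader,
        batch_transforms=normalize_on_cpu,
        microbatch_transforms=scale_on_device,
    )
    # Per PATCH 1/8, passing both `device_transforms` and `batch_transforms`
    # raises a ValueError, and passing only `device_transforms` emits a
    # VersionedDeprecationWarning and is treated as `batch_transforms`.

The resulting `train_dspec` can then be passed to the Trainer as `train_dataloader=train_dspec`, as done in the `test_transforms` test added in PATCH 5/8.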