Add batch/microbatch transforms #3703

Merged · 8 commits · Nov 11, 2024

Changes from 2 commits
38 changes: 33 additions & 5 deletions composer/core/data_spec.py
@@ -14,7 +14,7 @@
import torch.utils.data
from torch.utils.data.distributed import DistributedSampler

from composer.utils import dist, ensure_tuple
from composer.utils import VersionedDeprecationWarning, dist, ensure_tuple

if TYPE_CHECKING:
from composer.core.types import Batch
@@ -155,11 +155,20 @@ class DataSpec:
num_tokens (int, optional): The total number of tokens in an epoch. This field is used by the
:class:`.Timestamp` (training progress tracker).

device_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
batch once it has been moved onto the device. For example, this function can be used for GPU-based
device_transforms ((Batch) -> Batch, optional): Deprecated argument. Please use ``batch_transforms`` for batch
level transformations on CPU and ``microbatch_transforms`` for microbatch level transformations on target
device.

batch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
batch before it is moved onto the device. For example, this function can be used for CPU-based
normalization. It can modify the batch in-place, and it should return the modified batch. If not specified,
the batch is not modified.

microbatch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
microbatch before it is moved onto the device. For example, this function can be used for GPU-based
normalization. It can modify the microbatch in-place, and it should return the modified microbatch. If not
specified, the microbatch is not modified.

split_batch ((Batch, (int | float)) -> Sequence[Batch], optional): Function called by the :class:`.Trainer` to
split a batch (the first parameter) into microbatches of a given size (the second parameter). If
the ``dataloader`` yields batches not of type :class:`torch.Tensor`, Mapping, tuple, or list, then
@@ -186,13 +195,32 @@ def __init__(
num_samples: Optional[int] = None,
num_tokens: Optional[int] = None,
device_transforms: Optional[Callable[[Batch], Batch]] = None,
batch_transforms: Optional[Callable[[Batch], Batch]] = None,
microbatch_transforms: Optional[Callable[[Batch], Batch]] = None,
split_batch: Optional[Callable[[Batch, Union[int, float]], Sequence[Batch]]] = None,
get_num_samples_in_batch: Optional[Callable[[Batch], Union[int, float]]] = None,
get_num_tokens_in_batch: Optional[Callable[[Batch], Union[int, dict[str, int]]]] = None,
) -> None:
self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader
self.num_tokens = num_tokens
self.device_transforms = self._default_device_transforms if device_transforms is None else device_transforms
if device_transforms is not None:
if batch_transforms is not None:
raise ValueError(
'Cannot specify both `device_transforms` and `batch_transforms`. Please use `batch_transforms` for '
'batch level transformations on CPU and `microbatch_transforms` for microbatch level transformations '
'on target device.',
)
warnings.warn(
VersionedDeprecationWarning(
'The `device_transforms` argument is deprecated. Please use `batch_transforms` for batch level '
'transformations on CPU and `microbatch_transforms` for microbatch level transformations on target '
'device.',
'v0.28.0',
),
)
batch_transforms = device_transforms
self.batch_transforms = self._default_transforms if batch_transforms is None else batch_transforms
self.microbatch_transforms = self._default_transforms if microbatch_transforms is None else microbatch_transforms
self.split_batch = default_split_batch if split_batch is None else split_batch
self.get_num_samples_in_batch = self._default_get_num_samples_in_batch if get_num_samples_in_batch is None else get_num_samples_in_batch
self._get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch
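
For reference, a minimal usage sketch of the new constructor arguments (not part of this diff; the dataset and transform bodies below are made up for the example):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from composer.core import DataSpec

dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
dataloader = DataLoader(dataset, batch_size=16)


def scale_on_cpu(batch):
    # Runs once per batch, on the host, before the batch is moved to the device.
    inputs, targets = batch
    return inputs / 255.0, targets


def add_noise_on_device(microbatch):
    # Runs on each microbatch after it has been moved to the target device.
    inputs, targets = microbatch
    return inputs + 0.01 * torch.randn_like(inputs), targets


data_spec = DataSpec(
    dataloader,
    batch_transforms=scale_on_cpu,
    microbatch_transforms=add_noise_on_device,
)
```

When this `DataSpec` is passed to the `Trainer` (e.g. as `train_dataloader`), `scale_on_cpu` runs on each full batch before it is split, and `add_noise_on_device` runs on each microbatch after `batch_to_device`.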
@@ -242,7 +270,7 @@ def __init__(
'For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler.',
)

def _default_device_transforms(self, batch: Batch):
def _default_transforms(self, batch: Batch):
return batch

def _default_get_num_samples_in_batch(self, batch: Batch) -> int:
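The backward-compatibility branch in `__init__` above keeps the deprecated `device_transforms` argument working: combining it with `batch_transforms` raises, and using it alone emits a `VersionedDeprecationWarning` while the callable is carried over as the batch-level transform. A small sketch of that behaviour (illustrative, reusing the `dataloader` from the previous sketch):

```python
import warnings

from composer.core import DataSpec

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # Deprecated path: still accepted, but now warns.
    spec = DataSpec(dataloader, device_transforms=lambda batch: batch)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
# The deprecated callable is stored as the batch-level transform.
assert spec.batch_transforms is not None

# Mixing old and new arguments is rejected:
# DataSpec(dataloader, device_transforms=..., batch_transforms=...)  -> ValueError
```
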
9 changes: 6 additions & 3 deletions composer/trainer/trainer.py
@@ -2622,7 +2622,7 @@ def _train_loop(self) -> None:
self._rng_state = None
continue

self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
self.state.batch = self._train_data_spec.batch_transforms(self.state.batch)
rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)

@@ -3034,6 +3034,7 @@ def _train_microbatches(

for microbatch_idx, self.state.batch in enumerate(microbatches):
self.state.batch = self.state.device.batch_to_device(self.state.batch)
self.state.batch = self._train_data_spec.microbatch_transforms(self.state.batch)
is_final_microbatch = microbatch_idx + 1 == len(microbatches)
microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch)

@@ -3310,7 +3311,8 @@ def predict_batch_end(self, state: State, logger: Logger) -> None:
self.state.batch = self.state.device.batch_to_device(self.state.batch)

# Perform any device transforms
self.state.batch = data_spec.device_transforms(self.state.batch)
self.state.batch = data_spec.batch_transforms(self.state.batch)
self.state.batch = data_spec.microbatch_transforms(self.state.batch)

# Count the batch size and num tokens before any events run
rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch)
@@ -3586,7 +3588,7 @@ def _eval_loop(
)

for self.state.batch in self._iter_dataloader(TrainerMode.EVAL):
self.state.batch = data_spec.device_transforms(self.state.batch)
self.state.batch = data_spec.batch_transforms(self.state.batch)

# Count the batch size and num tokens before any events run
rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch)
@@ -3616,6 +3618,7 @@ def _eval_loop(
microbatches = data_spec.split_batch(device_batch, evaluator.device_eval_microbatch_size)
for i, self.state.batch in enumerate(microbatches):
self.state.batch = self.state.device.batch_to_device(self.state.batch)
self.state.batch = data_spec.microbatch_transforms(self.state.batch)
last_microbatch = i == len(microbatches) - 1
skip_metric_update = False
# Distributed samplers pad batches to be the same size. If using a
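
Taken together, the trainer changes give the following order of operations in the train and eval loops. This is a simplified sketch of the flow, not the actual Trainer code; the function name and arguments are illustrative:

```python
def run_batch(data_spec, device, batch, microbatch_size):
    # 1. Batch-level transform: runs on the host, on the full batch.
    batch = data_spec.batch_transforms(batch)
    # 2. Split the batch into microbatches.
    for microbatch in data_spec.split_batch(batch, microbatch_size):
        # 3. Move each microbatch to the target device...
        microbatch = device.batch_to_device(microbatch)
        # 4. ...then apply the microbatch-level transform on-device.
        microbatch = data_spec.microbatch_transforms(microbatch)
        # 5. Forward/backward (train) or metric updates (eval) happen here.
```

In the prediction path (third `trainer.py` hunk above), the whole batch is moved to the device first, and both transforms are then applied to it on-device.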