This homework consists of two parts:
- Part 1: ./practice_part1.ipynb - memory-efficient training and inference
- Part 2: ./practice_part2.ipynb - implementing model and sequence parallelism
Part 1 will require you to implement memory-saving techniques such as offloading and gradient checkpointing / accumulation. To implement offloading, you may either write your own low-level code, or use the recommended trick: write your own autograd.Function (similar to the gradient checkpointing function) that moves the requisite modules onto the device just in time for computation. Our practice video ('25) contains some tips on extending autograd functions, but those are optional.
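To give a feel for that trick, here is a minimal, non-authoritative sketch of just-in-time offloading: the weights live on CPU and are copied to the GPU only for the duration of a forward pass (and a recomputed backward, as in checkpointing). It assumes a single CUDA device and CPU-resident inputs; the names OffloadedForward and offload_module_call are made up for illustration, not part of any API.

import torch
import torch.nn as nn

class OffloadedForward(torch.autograd.Function):
    # runs `module` on `device` while keeping its weights on CPU the rest of the time
    @staticmethod
    def forward(ctx, module, device, x):
        ctx.module, ctx.device = module, device
        ctx.save_for_backward(x)
        module.to(device)                     # bring weights onto the device just in time
        with torch.no_grad():
            out = module(x.to(device)).cpu()  # compute, then send activations back to CPU
        module.to("cpu")                      # free GPU memory right away
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        module, device = ctx.module, ctx.device
        module.to(device)                     # reload weights for the backward pass
        with torch.enable_grad():
            x_dev = x.detach().to(device).requires_grad_(True)
            out = module(x_dev)               # recompute forward, as gradient checkpointing does
            out.backward(grad_output.to(device))
        module.to("cpu")                      # parameter .grad tensors move back to CPU as well
        # forward() took (module, device, x), so return one gradient per argument
        return None, None, x_dev.grad.cpu()

def offload_module_call(module, x, device="cuda"):
    # as with checkpointing, x must require grad for backward to reach this block
    return OffloadedForward.apply(module, device, x)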
Part 2 is much more convenient with multiple GPUs, though it can potentially be solved by emulating GPUs with CPU-only code. For YSDA and HSE students, you can use either DataSphere or one of the GPU servers available for this course (recommended). If you are an online student, you can try to register for Kaggle Kernels (they let you run on 2x T4) in a Jupyter-like interface. That said, implementing assignments B and C in Kaggle is more difficult than intended. For non-enrolled online students, we recommend option A unless you have access to some other multi-GPU hardware or are intentionally masochistic.
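If you go the CPU-only route, the usual approach is to launch several worker processes and have them communicate through torch.distributed with the "gloo" backend, which stands in for the GPU ranks. A tiny sketch (the port value and the worker function are arbitrary choices, not requirements):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"       # any free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # each process plays the role of one "GPU"; all-reduce stands in for
    # the cross-device communication used in tensor / sequence parallelism
    shard = torch.full((4,), float(rank))
    dist.all_reduce(shard, op=dist.ReduceOp.SUM)
    print(f"rank {rank} sees {shard.tolist()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)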
- PyTorch gradient checkpointing - API reference
- PyTorch native ZeRO - FullyShardedDataParallel (a short usage sketch follows this list)
- GPipe (one good implementation of pipelining) - arxiv
- Megatron-LM - one honking great implementation of large-scale training for transformers - repo
- DeepSpeed (a library of many tricks) - repo
- Alpa (automated parallelism in Jax) - https://github.com/alpa-projects/alpa
- ICML'22 tutorial: https://sites.google.com/view/icml-2022-big-model
- FairScale - sharded DDP and pipeline from Meta - repo
- tensor_parallel - automated tensor parallelism in PyTorch
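Since FullyShardedDataParallel comes up a lot in this homework's context, here is a minimal usage sketch, assuming the script is launched with torchrun (one process per GPU); the layer sizes and hyperparameters are placeholders:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")      # assumes launch via `torchrun --nproc_per_node=N`
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
model = FSDP(model.cuda())           # parameters (and their optimizer state) are sharded across ranks
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # build the optimizer after wrapping

x = torch.randn(8, 1024, device="cuda")
loss = model(x).square().mean()
loss.backward()                      # gradients are reduce-scattered automatically
optimizer.step()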
During the in-class practice, we also had several PyTorch code examples that could come in handy when training large models:
Gradient checkpointing:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
class Checkpoint(nn.Sequential):
    def forward(self, *inputs):
        # do not store intermediate activations; recompute this block's forward during backward
        return checkpoint(super().forward, *inputs)
class Echo(nn.Module):
    def __init__(self, msg: str):
        super().__init__()
        self.msg = msg  # print this message during forward (for debugging)

    def forward(self, x):
        print("forward", self.msg)
        return x
model = nn.Sequential(
    Checkpoint(nn.Linear(1000, 1000), nn.ReLU(), Echo("layer1 done"),
               nn.Linear(1000, 1000), nn.ReLU(), Echo("layer2 done")),
    Checkpoint(nn.Linear(1000, 1000), nn.ReLU(), Echo("layer3 done"),
               nn.Linear(1000, 1000), nn.ReLU(), Echo("layer4 done")),
    nn.Linear(1000, 1000), nn.ReLU(), Echo("layer5 done"),
)
inputs = torch.randn(16, 1000, requires_grad=True)
# note: we must set the inputs' requires_grad=True because checkpointing requires at least one input with grad for backprop to reach the checkpointed blocks
outputs = model(inputs)
outputs.norm().backward() # Echo layers will print in the following order: 1 2 3 4 5 3 4 1 2
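The checkpoint_sequential helper imported above does the same segmentation for you: it splits an nn.Sequential into segments and checkpoints each one. A quick sketch reusing the imports and the `inputs` tensor from the example above (layer sizes are arbitrary):

flat_model = nn.Sequential(
    nn.Linear(1000, 1000), nn.ReLU(),
    nn.Linear(1000, 1000), nn.ReLU(),
    nn.Linear(1000, 1000), nn.ReLU(),
)
outputs = checkpoint_sequential(flat_model, 2, inputs)  # checkpoint the chain in 2 segments
outputs.norm().backward()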