# Training a 405B model

Here we are going to use a huge cluster to train Llama 3.1 405B. **This does not use LoRA!** We are fully training all of the weights of a 405B model in plain PyTorch.

## Use flash attention

```bash
pip install packaging
pip install ninja
pip install flash-attn --no-build-isolation
```

[Source](https://github.com/Dao-AILab/flash-attention)

```python
model = AutoModelForCausalLM.from_pretrained(
    ...
    attn_implementation="flash_attention_2",
)
```
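
For context, a fuller call might look like the sketch below. The model id and dtype here are assumptions for illustration (flash attention requires fp16 or bf16 weights), not the training script's exact arguments:

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical example; the model id and dtype are assumptions, not taken from the training script.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-405B-Instruct",  # assumed model id
    torch_dtype=torch.bfloat16,                 # flash attention requires fp16/bf16
    attn_implementation="flash_attention_2",
)
```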

## Download model weights

There are two options here:

1. A shared network drive
2. Locally on each node

Node local storage is **vastly** faster. For some concrete numbers: while running this script on 8 8xH100 80GB nodes, initializing from the shared network drive took 50 minutes, while node local storage only took 3 minutes.

There's a download script in this repo for convenience; run this on node 0:

```bash
cd distributed-training-guide/10-training-llama-405b
python download.py
```
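
If you want to script the node-local download yourself, it roughly amounts to a `huggingface_hub` snapshot download into a local directory. This is only a sketch; the repo id and target path are assumptions, not necessarily what `download.py` does:

```python
# Sketch of a node-local weight download, assuming huggingface_hub is installed.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="meta-llama/Meta-Llama-3.1-405B-Instruct",  # assumed model id
    local_dir="/mnt/local/llama-405b",                  # assumed node-local path
    max_workers=16,                                     # parallel downloads help on fast links
)
```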

## Loading pretrained weights

There are three parts to this:

1. Using device_map "cpu" for rank 0, and the meta device for rank > 0
2. Using from_config instead of from_pretrained on rank > 0
3. FSDP.sync_module_states=True

We can't actually use device_map "auto", because it would fully utilize the rank 0 GPU, and when we try to initialize FSDP later there would be no memory left to allocate. Instead we use device_map="cpu" on rank 0:

```python
if rank == 0:
    with torch.device("cpu"):
        model = AutoModelForCausalLM.from_pretrained(...)
else:
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
```

Then later, sync_module_states=True in the FSDP constructor makes sure the weights are broadcast from rank 0 to the other ranks.
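
Putting it together, the FSDP call looks roughly like the following. This is a sketch rather than the script's exact arguments; `wrap_policy` is the auto-wrap policy defined in the next section, and the `param_init_fn` is the standard way to materialize the meta-device parameters on ranks > 0 so they can receive the broadcast:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# Sketch of the FSDP constructor call; `rank` is this process's global rank.
model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,  # broadcast rank 0's CPU weights to every other rank
    # Ranks > 0 only hold meta tensors; allocate real (empty) storage for them.
    param_init_fn=(
        (lambda module: module.to_empty(device=torch.cuda.current_device(), recurse=False))
        if rank != 0
        else None
    ),
)
```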

## Sharding Llama 405B

Most of the tutorials on training Llama 405B just shard the `LlamaDecoderLayer` (there are 126 of them). However, during testing I also found that sharding the `nn.Embedding` layer at the beginning of the network improved throughput and reduced memory usage. We can use the `transformer_auto_wrap_policy` to target the specific classes for those layers:

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer, nn.Embedding},
)
FSDP(..., auto_wrap_policy=wrap_policy)
```

## Gradient checkpointing

Another piece of this is gradient checkpointing, which saves a lot of memory. This piece of code has to go **after** the FSDP constructor! I'm not exactly sure of the reason, but it doesn't work before the FSDP initialization.

The method we are using is somewhat hidden in PyTorch, but it is exactly what [accelerate uses under the hood](https://github.com/huggingface/accelerate/blob/v0.34.2/src/accelerate/accelerator.py#L1492), so rest assured that it is a "standard" way of doing it:

```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=checkpoint_wrapper, auto_wrap_policy=wrap_policy
)
```

## Fused optimizer implementation

When using the CPUOffload feature of FSDP, the optimizer runs entirely on the CPU. This is because there is a significant cost to transferring data to and from the GPU when doing optimizer.step(). At the time of writing there are open issues on how to overlap optimizer.step() with the next forward() call.
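
For reference, CPU offload is enabled by passing a `CPUOffload` config to the FSDP constructor. A minimal sketch (the surrounding arguments are omitted here and assumed to match the earlier FSDP call):

```python
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

# Keep sharded parameters (and therefore gradients and optimizer state) on the CPU.
model = FSDP(
    model,
    cpu_offload=CPUOffload(offload_params=True),
    # ... auto_wrap_policy, sync_module_states, etc. as above
)
```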

By default the optimizer will update each parameter tensor separately on the CPU, but there are flags you can enable to get a bit of a speedup:

```python
torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True)
```

## zero_grad(set_to_none=???)

You may have seen the set_to_none argument in [optimizer.zero_grad()](https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html). According to the docs:

> This will in general have lower memory footprint, and can modestly improve performance.

Basically, set_to_none=True deallocates the gradients after they are used. In cases where we aren't memory constrained, keeping the gradients around (and reducing the amount of allocations) is good for performance. However, when we are memory constrained, setting them to None gives us more memory to use.

When we are using CPUOffload though, the gradients we keep around live on the CPU, so there isn't really a GPU memory cost to keeping them!

```python
optimizer.zero_grad(set_to_none=args.cpu_offload == "off")
```

## Launch command

We provide a customized launch.sh script here, based on the bash command for spawning torchrun on all available nodes:

```bash
cd distributed-training-guide/10-training-llama-405b
bash launch.sh # NOTE: this is non-blocking
```

Also note that launch.sh sets `HF_HOME` as an environment variable in the tmux session, so if you've not used the default value of `/home/ubuntu/.cache/huggingface`, please update the script!

You can change the hostnames in the `hosts` file in this directory.

## Monitoring

The log files are really useful for monitoring the progress of everything. Here's a bash command for tailing all of them at once:

```bash
cd distributed-training-guide/10-training-llama-405b
find ../logs/ -name \*stderr.log | xargs tail -f
```

Additionally, we have a top-like utility script for monitoring the entire cluster at the top level of this repo:

```bash
cd distributed-training-guide/10-training-llama-405b
python ../top-cluster.py hosts
```

## Run statistics

Running this with 64 H100 GPUs (8 separate nodes) has the following stats:

- ~30s per iteration (data/forward/backward/update). The breakdown is:
  - data: ~2ms
  - forward: ~7s
  - backward: ~19s
  - update: ~4s
- Peak Memory Allocated: 52.9GB
- Peak Memory Reserved: 77.9GB

Note that reserved memory reflects PyTorch's allocation caching, so it sits above the peak allocated memory.
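
These two numbers map directly onto PyTorch's CUDA memory statistics; a small snippet like this (a sketch, not the repo's exact logging code) reports them per rank:

```python
import torch

# Report peak memory on this rank. Call torch.cuda.reset_peak_memory_stats()
# each iteration if you want per-step peaks instead of per-run peaks.
allocated_gb = torch.cuda.max_memory_allocated() / 1e9
reserved_gb = torch.cuda.max_memory_reserved() / 1e9
print(f"Peak allocated: {allocated_gb:.1f}GB | Peak reserved: {reserved_gb:.1f}GB")
```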

## Other notes on settings that didn't affect throughput

- Allowing TF32 had no impact on throughput (`torch.backends.cudnn.allow_tf32` and `torch.backends.cuda.matmul.allow_tf32`)
- Enabling benchmarking had no impact on throughput (`torch.backends.cudnn.benchmark = True`)
- Using the cuDNN SDPA backend was slower (`attn_implementation="sdpa"` and `torch.backends.cuda.enable_cudnn_sdp(True)`)
- torch.compile had no impact (`use_orig_params=True` and `torch.compile` after the FSDP constructor)
- Very minimal testing of [NCCL environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html) either made things worse or had no impact
- `PYTORCH_NO_CUDA_MEMORY_CACHING=1` made enough memory available that `--batch-size 2` or higher sequence lengths were possible, but it was much, much slower.
  - It's possible that some well-placed calls to `torch.cuda.empty_cache()` could achieve this without the throughput loss.
- Only `FULL_SHARD` works; other sharding strategies fail silently.