Commit 8f683f0

Merge pull request #19 from LambdaLabsML/llama-405b
[WIP] Llama 405b
2 parents 69e36c0 + d636f78 commit 8f683f0

File tree

25 files changed: +767 -36 lines changed

01-single-gpu/README.md (+3 -5)

@@ -23,11 +23,9 @@ This is a basic train from scratch script. Here are some quick facts:
 device = torch.device("cuda")
 dtype = torch.bfloat16

-config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True)
-model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).to(
-    dtype=dtype, device=device
-)
-tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
+config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
+model = AutoModelForCausalLM.from_config(config).to(dtype=dtype, device=device)
+tokenizer = AutoTokenizer.from_pretrained(args.model_name)
 ```

 2. We save checkpoints into `args.save_dir/args.experiment_name` - `--experiment-name` is a **unique** run identifier

01-single-gpu/train_llm.py (+1 -1)

@@ -42,7 +42,7 @@ def main():
     def _load_to_device(p):
         return torch.load(p, map_location=device, weights_only=True)

-    config = AutoConfig.from_pretrained(args.model_name)
+    config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
     model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype).to(device)
     tokenizer = AutoTokenizer.from_pretrained(args.model_name)


02-multi-gpu/README.md (+2 -2)

@@ -186,11 +186,11 @@ def rank0_first():
 ```

 ```diff
-- config = AutoConfig.from_pretrained(args.model_name)
+- config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
 - model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype).to(device)
 - tokenizer = AutoTokenizer.from_pretrained(args.model_name)
 + with rank0_first():
-+     config = AutoConfig.from_pretrained(args.model_name)
++     config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
 +     model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype).to(device)
 +     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
 ```

02-multi-gpu/train_llm.py (+1 -1)

@@ -59,7 +59,7 @@ def _load_to_device(p):
         return torch.load(p, map_location=device, weights_only=True)

     with rank0_first():
-        config = AutoConfig.from_pretrained(args.model_name)
+        config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
         model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype).to(device)
         tokenizer = AutoTokenizer.from_pretrained(args.model_name)


03-multi-node/README.md (+1 -1)

@@ -54,7 +54,7 @@ index 3130381..d5cb05c 100644
 - assert world_size == torch.cuda.device_count()

 - _LOGGER.info(f"rank={rank} world size={world_size}")
-+ _LOGGER.info(f"local rank={local_rank} rank={rank} world size={world_size}")
++ _LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")

 - device = torch.device(f"cuda:{rank}")
 + device = torch.device(f"cuda:{local_rank}")

03-multi-node/train_llm.py (+2 -2)

@@ -49,7 +49,7 @@ def main():
     local_rank = rank % torch.cuda.device_count()
     world_size = dist.get_world_size()

-    _LOGGER.info(f"local rank={local_rank} rank={rank} world size={world_size}")
+    _LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")

     device = torch.device(f"cuda:{local_rank}")
     dtype = torch.bfloat16
@@ -59,7 +59,7 @@ def _load_to_device(p):
         return torch.load(p, map_location=device, weights_only=True)

     with rank0_first():
-        config = AutoConfig.from_pretrained(args.model_name)
+        config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
         model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype).to(device)
         tokenizer = AutoTokenizer.from_pretrained(args.model_name)


04-job-launchers-bash/train_llm.py (+2 -2)

@@ -49,7 +49,7 @@ def main():
     local_rank = rank % torch.cuda.device_count()
     world_size = dist.get_world_size()

-    _LOGGER.info(f"local rank={local_rank} rank={rank} world size={world_size}")
+    _LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")

     device = torch.device(f"cuda:{local_rank}")
     dtype = torch.bfloat16
@@ -59,7 +59,7 @@ def _load_to_device(p):
         return torch.load(p, map_location=device, weights_only=True)

     with rank0_first():
-        config = AutoConfig.from_pretrained(args.model_name)
+        config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
         model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype).to(device)
         tokenizer = AutoTokenizer.from_pretrained(args.model_name)


04-job-launchers-deepspeed/README.md (+1 -1)

@@ -26,7 +26,7 @@ index ae1c66f..d5671b3 100644
 + local_rank = rank % torch.cuda.device_count()
   world_size = dist.get_world_size()

-  _LOGGER.info(f"local rank={local_rank} rank={rank} world size={world_size}")
+  _LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")
 @@ -306,6 +309,7 @@ def _get_parser() -> argparse.ArgumentParser:
   parser.add_argument("--log-freq", default=100, type=int)
   parser.add_argument("--ckpt-freq", default=500, type=int)

04-job-launchers-deepspeed/train_llm.py (+2 -2)

@@ -52,7 +52,7 @@ def main():
     local_rank = rank % torch.cuda.device_count()
     world_size = dist.get_world_size()

-    _LOGGER.info(f"local rank={local_rank} rank={rank} world size={world_size}")
+    _LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")

     device = torch.device(f"cuda:{local_rank}")
     dtype = torch.bfloat16
@@ -62,7 +62,7 @@ def _load_to_device(p):
         return torch.load(p, map_location=device, weights_only=True)

     with rank0_first():
-        config = AutoConfig.from_pretrained(args.model_name)
+        config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
         model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype).to(device)
         tokenizer = AutoTokenizer.from_pretrained(args.model_name)


04-job-launchers-mpirun/train_llm.py (+2 -2)

@@ -54,7 +54,7 @@ def main():
     local_rank = rank % torch.cuda.device_count()
     world_size = dist.get_world_size()

-    _LOGGER.info(f"local rank={local_rank} rank={rank} world size={world_size}")
+    _LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")

     device = torch.device(f"cuda:{local_rank}")
     dtype = torch.bfloat16
@@ -64,7 +64,7 @@ def _load_to_device(p):
         return torch.load(p, map_location=device, weights_only=True)

     with rank0_first():
-        config = AutoConfig.from_pretrained(args.model_name)
+        config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
         model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype).to(device)
         tokenizer = AutoTokenizer.from_pretrained(args.model_name)


04-job-launchers-slurm/train_llm.py (+2 -2)

@@ -49,7 +49,7 @@ def main():
     local_rank = rank % torch.cuda.device_count()
     world_size = dist.get_world_size()

-    _LOGGER.info(f"local rank={local_rank} rank={rank} world size={world_size}")
+    _LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")

     device = torch.device(f"cuda:{local_rank}")
     dtype = torch.bfloat16
@@ -59,7 +59,7 @@ def _load_to_device(p):
         return torch.load(p, map_location=device, weights_only=True)

     with rank0_first():
-        config = AutoConfig.from_pretrained(args.model_name)
+        config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
         model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype).to(device)
         tokenizer = AutoTokenizer.from_pretrained(args.model_name)


05-sharding-deepspeed/README.md (+1 -1)

@@ -80,7 +80,7 @@ Two main differences here:
 + local_rank = rank % torch.cuda.device_count()
   world_size = dist.get_world_size()

-  _LOGGER.info(f"local rank={local_rank} rank={rank} world size={world_size}")
+  _LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")

 @@ -73,10 +73,6 @@ def main():
   if len(tokenizer) > embedding_size:

05-sharding-deepspeed/train_llm.py (+2 -2)

@@ -53,14 +53,14 @@ def main():
     local_rank = rank % torch.cuda.device_count()
     world_size = dist.get_world_size()

-    _LOGGER.info(f"local rank={local_rank} rank={rank} world size={world_size}")
+    _LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")

     device = torch.device(f"cuda:{local_rank}")
     dtype = torch.bfloat16
     torch.cuda.set_device(device)

     with rank0_first():
-        config = AutoConfig.from_pretrained(args.model_name)
+        config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
         model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype).to(device)
         tokenizer = AutoTokenizer.from_pretrained(args.model_name)

6666

05-sharding-fsdp/train_llm.py (+2 -2)

@@ -64,7 +64,7 @@ def main():
     local_rank = rank % torch.cuda.device_count()
     world_size = dist.get_world_size()

-    _LOGGER.info(f"local rank={local_rank} rank={rank} world size={world_size}")
+    _LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")

     device = torch.device(f"cuda:{local_rank}")
     dtype = torch.bfloat16
@@ -74,7 +74,7 @@ def _load_to_device(p):
         return torch.load(p, map_location=device, weights_only=True)

     with rank0_first():
-        config = AutoConfig.from_pretrained(args.model_name)
+        config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
         tokenizer = AutoTokenizer.from_pretrained(args.model_name)
         # NOTE: meta device will not allocate any memory
         with torch.device("meta"):

10-training-llama-405b/README.md (new file, +167)

# Training a 405B model

Here we are going to utilize a huge cluster to train Llama 3.1 405B. **This does not utilize LoRA!** We are actually fully training the weights of a 405B model in plain PyTorch.

## Use flash attention

```bash
pip install packaging
pip install ninja
pip install flash-attn --no-build-isolation
```

[Source](https://github.com/Dao-AILab/flash-attention)

```python
model = AutoModelForCausalLM.from_pretrained(
    ...
    attn_implementation="flash_attention_2",
)
```

## Download model weights

There are two options here:

1. A shared network drive
2. Locally on each node

Node local storage is **vastly** faster. For some numbers: while running this script on 8 8xH100 80GB nodes, initializing from the shared network drive took 50 minutes, while node local storage took only 3 minutes.

This repo includes a download script for convenience; run it on node 0:

```bash
cd distributed-training-guide/10-training-llama-405b
python download.py
```

## Loading pretrained weights

There are three parts to this:

1. Using device_map "cpu" for rank 0, and the meta device for rank > 0
2. Using from_config instead of from_pretrained on rank > 0
3. FSDP.sync_module_states=True

We can't actually use device_map "auto", because this would fully utilize the rank 0 GPU, and when we try to initialize FSDP later we won't have any memory left to allocate. Instead we use device_map="cpu" on rank 0:

```python
if rank == 0:
    with torch.device("cpu"):
        model = AutoModelForCausalLM.from_pretrained(...)
else:
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
```

Then later, sync_module_states=True in the FSDP constructor will make sure the weights are broadcast from rank 0 to the other ranks.
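As a rough sketch of how that fits together (not the verbatim constructor call from train_llm.py; `local_rank` and `wrap_policy` are the names used elsewhere in this guide, and the `param_init_fn` shown is the standard recipe for giving meta-device parameters real storage to broadcast into):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Sketch only: sync_module_states=True broadcasts rank 0's CPU-loaded weights
# to every other rank; param_init_fn first moves meta-device parameters onto
# the GPU as empty storage so there is something to broadcast into.
model = FSDP(
    model,
    device_id=local_rank,
    sync_module_states=True,
    param_init_fn=lambda module: module.to_empty(
        device=torch.device("cuda"), recurse=False
    ),
    auto_wrap_policy=wrap_policy,  # see the sharding section below
)
```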
## Sharding Llama 405B

Most of the tutorials on training Llama 405B just shard the `LlamaDecoderLayer` (there are 126 of them). However, during testing I also found that sharding the `nn.Embedding` layer at the beginning of the network improved throughput and reduced memory usage. We can use the `transformer_auto_wrap_policy` to target the specific classes for those layers:

```python
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer, nn.Embedding},
)
FSDP(..., auto_wrap_policy=wrap_policy)
```

## Gradient checkpointing

Another piece of this is gradient checkpointing, which saves a lot of memory. This piece of code has to go **after** the FSDP constructor! I'm not exactly sure of the reason, but it doesn't work before the FSDP initialization.

The method we are using is kind of a hidden method in PyTorch, but this is actually exactly what [accelerate uses under the hood](https://github.com/huggingface/accelerate/blob/v0.34.2/src/accelerate/accelerator.py#L1492), so rest assured that it is a "standard" way of doing it:

```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=checkpoint_wrapper, auto_wrap_policy=wrap_policy
)
```

## Fused optimizer implementation

When using the CPUOffload feature of FSDP, the optimizer runs entirely on the CPU. This is because there is a significant cost to transferring data to and from the GPU on every optimizer.step(). At the time of writing there are open issues on how to overlap the optimizer.step() with the next forward() call.
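For context, CPU offload itself is enabled through the FSDP constructor. Here is a minimal sketch, assuming it is gated on the script's `--cpu-offload` flag (this README only confirms the `"off"` value, so the condition below just negates that):

```python
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Sketch only: offload_params=True keeps parameters (and therefore gradients
# and optimizer state) in CPU memory, which is why optimizer.step() runs on the CPU.
cpu_offload = CPUOffload(offload_params=True) if args.cpu_offload != "off" else None
model = FSDP(model, cpu_offload=cpu_offload)  # plus the other FSDP options shown above
```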
By default the optimizer updates each parameter tensor one at a time on the CPU, but there are flags you can enable to get a bit of a speedup:

```python
torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True)
```

## zero_grad(set_to_none=???)

You may have seen the set_to_none argument in [optimizer.zero_grad()](https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html). According to the docs:

> This will in general have lower memory footprint, and can modestly improve performance.

Basically, set_to_none=True will deallocate the gradients after they are used. In cases where we aren't memory constrained, keeping the gradients around (and reducing the amount of allocations) is good for performance. However, when we are memory constrained, setting them to None gives us more memory to use.

When we are using CPUOffload though, the gradients we are keeping around live on the CPU, so there isn't really a GPU memory cost to keeping them:

```python
optimizer.zero_grad(set_to_none=args.cpu_offload == "off")
```

## Launch command

We provide a customized launch.sh script here, based on the bash approach of spawning torchrun on all available nodes:

```bash
cd distributed-training-guide/10-training-llama-405b
bash launch.sh # NOTE: this is non blocking
```

Also note that launch.sh sets `HF_HOME` as an environment variable in the tmux session, so if you've not used the default value of `/home/ubuntu/.cache/huggingface`, please update the script!

You can change the hostnames in the `hosts` file in this directory.

## Monitoring

The log files are really useful for monitoring the progress of everything. Here's a bash command for tailing all of them at once:

```bash
cd distributed-training-guide/10-training-llama-405b
find ../logs/ -name \*stderr.log | xargs tail -f
```

Additionally, we have a top-like utility script for monitoring the entire cluster at the top level of this directory:

```bash
cd distributed-training-guide/10-training-llama-405b
python ../top-cluster.py hosts
```

## Run statistics

Running this with 64 H100 GPUs (8 separate nodes) has the following stats:

- ~30s per iteration (data/forward/backward/update). Breakdown is:
  - data: ~2ms
  - forward: ~7s
  - backward: ~19s
  - update: ~4s
- Peak Memory Allocated: 52.9GB
- Peak Memory Reserved: 77.9GB

Note that reserved memory has to do with PyTorch's allocation caching.
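For reference, the two peaks above can be read from torch.cuda's memory accounting; this is a sketch of one way to measure them, and the training script's own reporting may differ:

```python
import torch

# max_memory_allocated = peak memory actually occupied by tensors;
# max_memory_reserved  = peak memory held by PyTorch's caching allocator.
peak_allocated_gib = torch.cuda.max_memory_allocated() / 2**30
peak_reserved_gib = torch.cuda.max_memory_reserved() / 2**30
print(f"peak allocated: {peak_allocated_gib:.1f} GiB | peak reserved: {peak_reserved_gib:.1f} GiB")
```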
## Other notes on settings that didn't affect throughput

- Allowing TF32 had no impact on throughput (`torch.backends.cudnn.allow_tf32` and `torch.backends.cuda.matmul.allow_tf32`)
- Enabling benchmarking had no impact on throughput (`torch.backends.cudnn.benchmark = True`)
- Using cuDNN SDPA was slower (`attn_implementation="sdpa"` and `torch.backends.cuda.enable_cudnn_sdp(True)`)
- torch.compile had no impact (`use_orig_params=True` and `torch.compile` after the FSDP constructor)
- Very minimal testing of NCCL environment variables either made things worse or had no impact (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html)
- `PYTORCH_NO_CUDA_MEMORY_CACHING=1` made enough memory available that `--batch-size 2` or higher sequence lengths were possible, but it was much, much slower.
  - It's possible that some well placed calls to `torch.cuda.empty_cache()` could achieve this without the throughput loss.
- Only `FULL_SHARD` works. Others fail silently.

10-training-llama-405b/download.py (new file, +14)

import os
import torch
import transformers

os.environ["HF_HOME"] = "/home/ubuntu/.cache/huggingface"

model_name = "meta-llama/Meta-Llama-3.1-405B"

print(f"Downloading {model_name} to $HF_HOME = {os.environ['HF_HOME']}.")

config = transformers.AutoConfig.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
# NOTE: constructing the model under the meta device avoids allocating memory
# for the full 405B weights; the point of this script is just to download the
# checkpoint shards (plus config and tokenizer) into the HF cache.
with torch.device("meta"):
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name)

10-training-llama-405b/hosts (new file, +8)

ml-64-node-001
ml-64-node-002
ml-64-node-003
ml-64-node-004
ml-64-node-005
ml-64-node-006
ml-64-node-007
ml-64-node-008
