-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmodal_train_prm_rlhf_flow.py
68 lines (58 loc) · 1.63 KB
/
modal_train_prm_rlhf_flow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# train.py
import modal
import yaml
import os
from pathlib import Path
# --- Container image ---------------------------------------------------------
# Axolotl base image pinned by digest, extended with extra pip packages and
# the environment variables the training container runs with.
AXOLOTL_REGISTRY_SHA = "9578c47333bdcc9ad7318e54506b9adaf283161092ae780353d506f7a656590a"
image = (
    modal.Image.from_registry(f"winglian/axolotl@sha256:{AXOLOTL_REGISTRY_SHA}")
    .pip_install(
        "huggingface_hub==0.23.2",
        "hf-transfer==0.1.5",
        "wandb==0.16.3",
        "fastapi==0.110.0",
        "pydantic==2.6.3",
    )
    .env(
        {
            # Hugging Face hub cache location. No volume is mounted at
            # /pretrained in this file, so the cache is presumably ephemeral.
            "HUGGINGFACE_HUB_CACHE": "/pretrained",
            # Enables the hf-transfer download backend installed above.
            "HF_HUB_ENABLE_HF_TRANSFER": "1",
            # NOTE(review): unit presumably seconds - confirm with Axolotl.
            "AXOLOTL_NCCL_TIMEOUT": "60",
        }
    )
    # Clear the base image's ENTRYPOINT so it does not wrap the command
    # Modal starts in the container (presumably; confirm with Modal docs).
    .entrypoint([])
)

app = modal.App("train-hf", image=image)

# --- Time constants (seconds) ------------------------------------------------
MINUTES = 60
HOURS = MINUTES * 60

# Named persistent volume; created on first use if it does not exist yet.
training_vol = modal.Volume.from_name("training-data", create_if_missing=True)
@app.function(
    cpu=8,
    gpu=modal.gpu.H100(),
    timeout=20 * HOURS,
    volumes={"/training": training_vol},
    secrets=[
        modal.Secret.from_name("hf-token"),
        modal.Secret.from_name("wandb-token"),
    ],
)
def run_training(config):
    """Run Axolotl training on a Modal H100 container.

    The configuration is serialized to ``/training/config.yml`` on the
    mounted ``training-data`` volume, then ``python -m axolotl.cli.train``
    is run against that file. Any upload to the Hugging Face Hub is left to
    Axolotl and is driven by the config itself (presumably its
    ``push_to_hub``-style settings; not verified here).

    Args:
        config: Parsed Axolotl configuration, as loaded from YAML.

    Raises:
        subprocess.CalledProcessError: If the Axolotl training process exits
            with a non-zero status.
    """
    import subprocess

    # Write the config onto the volume so the training CLI can read it.
    config_path = Path("/training/config.yml")
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f)

    # check=True: a failed training run must fail this Modal function instead
    # of returning normally and looking like a success.
    subprocess.run(
        ["python", "-m", "axolotl.cli.train", str(config_path)],
        check=True,
    )
@app.local_entrypoint()
def main(config_path: str = "prm_rlhf_flow/qwen.yml"):
    """Load a local Axolotl YAML config and launch remote training on Modal.

    Args:
        config_path: Path to the YAML config on the local machine, resolved
            relative to the current working directory. It defaults to the
            previously hard-coded ``prm_rlhf_flow/qwen.yml``. As a
            local-entrypoint parameter it can presumably be overridden from
            the CLI (``--config-path``); confirm this against the installed
            Modal version.

    Raises:
        ValueError: If the file does not contain a YAML mapping. For example,
            an empty file makes ``yaml.safe_load`` return ``None``.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Fail fast locally rather than shipping an unusable config to a paid GPU
    # container.
    if not isinstance(config, dict):
        raise ValueError(
            f"{config_path} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )

    run_training.remote(config)
if __name__ == "__main__":
    # `modal run train.py` invokes the local entrypoint inside a running app.
    # A plain `python train.py` does not, so the app is started explicitly here.
    # Otherwise `run_training.remote` would be called with no app running.
    with app.run():
        main()