forked from rawsh/mirrorllm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
modal_train_prm_st.py
66 lines (59 loc) · 2.17 KB
/
modal_train_prm_st.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import modal
# Coordinates of the NVIDIA CUDA base image pulled from the registry.
cuda_version = "12.4.0"  # should be no greater than host CUDA version
flavor = "devel"  # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = "-".join((cuda_version, flavor, operating_sys))

# Start from the CUDA image with Python 3.11 added, then make git available.
image = modal.Image.from_registry("nvidia/cuda:" + tag, add_python="3.11")
image = image.apt_install("git")

# These are installed ahead of flash-attn: it is installed with
# --no-build-isolation, so its build uses whatever is already in the
# environment rather than an isolated build env.
for _pkg in ("torch", "packaging", "wheel"):
    image = image.pip_install(_pkg)
image = image.run_commands("pip install flash-attn --no-build-isolation")

# Remaining dependencies, one pip_install step per package.
for _pkg in (
    "transformers",
    "accelerate",
    "numpy",
    "datasets",
    "wandb",
    "bitsandbytes",
    "matplotlib",
    "seaborn",
):
    image = image.pip_install(_pkg)
# Readable time units for the function timeout.
MINUTES = 60  # seconds
HOURS = MINUTES * 60

# Named Modal volume "prm-tmp"; created if it does not exist yet.
vol = modal.Volume.from_name("prm-tmp", create_if_missing=True)

# Modal app whose functions run in the image built for this script.
app = modal.App("train_prm", image=image)

# Import the training entry point inside the image's import context, so the
# import is tied to the container image rather than the local environment.
with image.imports():
    from mcts.train_reward import train_reward_model
@app.function(
    cpu=2.0,
    # GPU is requested by name: the modal.gpu.* config objects are deprecated
    # in current Modal releases in favour of string specs.
    # Alternatives: "A10G", "A100-40GB", "A100-40GB:4" (four GPUs).
    gpu="H100",
    timeout=20 * HOURS,
    secrets=[
        modal.Secret.from_name("hf-token"),
        modal.Secret.from_name("wandb-token"),
    ],
    volumes={"/out": vol},
)
def train_reward_model_upload_to_hf():
    """Run one round of process-reward-model training remotely on Modal.

    Calls ``train_reward_model`` with a pinned revision of
    ``rawsh/mirrorqwen2.5-0.5b-prm`` as the starting checkpoint and the
    ORPO-2 PRM dataset as training data. ``output_model_name`` is the same
    repository as ``model_name``; judging by this function's name the result
    is uploaded to the Hugging Face Hub there (confirm in
    ``mcts.train_reward``).

    The container gets the ``hf-token`` and ``wandb-token`` secrets and has
    the ``prm-tmp`` volume mounted at ``/out``.
    """
    train_reward_model(
        model_name="rawsh/mirrorqwen2.5-0.5b-prm",
        # Starting checkpoint is pinned by commit hash. Known revisions:
        #   aed1bcf7d3d984272e329c3843f9c5fd0dfe5ca5  base
        #   42e07d1b708282ac2aae338050d8116f8c69398d  st0
        #   80da7ccc4f107e0cb6bf937d61be4702badfb96b  st1
        #   4d618515c90069993f4b32e4201783efdeebbc22  st2
        # NOTE: the orpo2 PRM was trained incorrectly - it used st0 as its
        # base model as well.
        model_revision="e49e4ca7c847194be48c42c52ad8f871da204300",  # orpo2
        dataset_path="rawsh/mirrorqwen2.5-0.5B-gsm8k-PRM-data-ORPO-2",
        output_model_name="rawsh/mirrorqwen2.5-0.5b-prm",
        disable_binning=False,
    )
@app.local_entrypoint()
def main():
    """Local entry point: dispatch the training function to Modal."""
    # .remote() executes the function in Modal's cloud, not on this machine.
    train_reward_model_upload_to_hf.remote()