train_config.py
import abc
import dataclasses


@dataclasses.dataclass
class BaseConfig(abc.ABC):
    """Base class that serializes its dataclass fields into CLI arguments."""

    def to_config_string(self) -> str:
        # Render each field as "--field-name value". Boolean fields act as
        # store-true flags: True emits the bare switch, False emits nothing.
        config_str = ""
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if isinstance(value, bool) and value:
                config_str += f"--{field.name.replace('_', '-')} "
            elif not isinstance(value, bool):
                config_str += f"--{field.name.replace('_', '-')} {value} "
        return config_str.strip()
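
# A minimal sketch of the serialization rules above, using a hypothetical
# subclass that is not part of this module:
#
#     @dataclasses.dataclass
#     class _Demo(BaseConfig):
#         fp16: bool = True
#         lr: float = 0.0001
#
#     _Demo().to_config_string()  # -> "--fp16 --lr 0.0001"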


@dataclasses.dataclass
class DistributedConfig(BaseConfig):
    """Arguments consumed by the torchrun launcher."""

    nproc_per_node: int
    nnodes: int
    node_rank: int
    master_addr: str
    master_port: str


@dataclasses.dataclass
class DatasetConfig(BaseConfig):
    """Tokenizer and dataset arguments."""

    vocab_file: str
    merge_file: str
    # data_path: str
    split: str
    mock_data: bool  # store-true flag; emitted only when True


@dataclasses.dataclass
class ModelConfig(BaseConfig):
    """Model architecture and optimization hyperparameters."""

    tensor_model_parallel_size: int
    pipeline_model_parallel_size: int
    num_layers: int
    hidden_size: int
    num_attention_heads: int
    seq_length: int
    max_position_embeddings: int
    micro_batch_size: int
    global_batch_size: int
    lr: float
    train_iters: int
    lr_decay_iters: int
    lr_decay_style: str
    min_lr: float
    weight_decay: float
    lr_warmup_fraction: str  # kept as a string and passed through verbatim
    clip_grad: float
    fp16: bool
    loss_scale: float
    failslow_aware: bool


@dataclasses.dataclass
class TrainConfig(BaseConfig):
    """Top-level config that assembles the full torchrun command line."""

    distributed_config: DistributedConfig
    dataset_config: DatasetConfig
    model_config: ModelConfig
    log_interval: int
    save_interval: int
    eval_interval: int
    eval_iters: int
    distributed_backend: str
    save: str  # checkpoint save path
    load: str  # checkpoint load path

    def to_config_string(self) -> str:
        # Serialize only the scalar fields here; the nested BaseConfig fields
        # are rendered separately and spliced into the command below.
        other_config_str = ""
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if not isinstance(value, BaseConfig):
                other_config_str += f"--{field.name.replace('_', '-')} {value} "
        # --sequence-parallel is always appended, regardless of config values.
        return "torchrun {} pretrain_gpt.py {} {} --sequence-parallel {}".format(
            self.distributed_config.to_config_string(),
            self.model_config.to_config_string(),
            self.dataset_config.to_config_string(),
            other_config_str,
        ).strip()
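

# A minimal usage sketch. All values below are illustrative placeholders,
# not defaults shipped with this module:
if __name__ == "__main__":
    train_config = TrainConfig(
        distributed_config=DistributedConfig(
            nproc_per_node=8,
            nnodes=1,
            node_rank=0,
            master_addr="127.0.0.1",
            master_port="6000",
        ),
        dataset_config=DatasetConfig(
            vocab_file="gpt2-vocab.json",
            merge_file="gpt2-merges.txt",
            split="949,50,1",
            mock_data=True,
        ),
        model_config=ModelConfig(
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=1,
            num_layers=24,
            hidden_size=1024,
            num_attention_heads=16,
            seq_length=1024,
            max_position_embeddings=1024,
            micro_batch_size=4,
            global_batch_size=32,
            lr=0.00015,
            train_iters=5000,
            lr_decay_iters=3200,
            lr_decay_style="cosine",
            min_lr=1.0e-5,
            weight_decay=0.01,
            lr_warmup_fraction=".01",
            clip_grad=1.0,
            fp16=True,
            loss_scale=4096.0,
            failslow_aware=False,
        ),
        log_interval=100,
        save_interval=1000,
        eval_interval=1000,
        eval_iters=10,
        distributed_backend="nccl",
        save="./checkpoints",
        load="./checkpoints",
    )
    # Prints the assembled "torchrun ... pretrain_gpt.py ..." command string.
    print(train_config.to_config_string())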