-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain.py
46 lines (37 loc) · 1.35 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from dataclasses import dataclass, field, asdict
from typing import Optional
import transformers
import os
# Suppress FutureWarnings globally (installed before the remaining imports run,
# so deprecation chatter from heavy dependencies is silenced too).
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
# Configure the root logger once for the whole script: timestamp + level prefix.
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Project-local helper that builds the train dataset/collator for a task.
from data.cptdata import get_task_data_module
@dataclass
class TrainingConfig:
    """CLI configuration for a continued-pretraining run.

    Parsed from argv by ``transformers.HfArgumentParser`` alongside
    ``TrainingArguments`` and forwarded (via ``asdict``) to
    ``get_task_data_module``.
    """
    task_name: str          # which task/dataset to train on
    block_size: int         # token block length used when packing the data
    rehersal_rate: float    # NOTE: misspelling kept — the field name is part of the CLI interface
    model_name: str         # HF model id or local path for the base model
    subsample_ratio: float  # fraction of the task data to use
    wandb_project: Optional[str] = field(default="synthetic-continued-pretraining")

    def __post_init__(self):
        # Export the W&B project so the HF Trainer's wandb integration picks it
        # up. Guard against None: os.environ values must be strings, and the
        # previous unconditional assignment raised TypeError when the project
        # was explicitly set to None (disabled).
        if self.wandb_project is not None:
            os.environ['WANDB_PROJECT'] = self.wandb_project
def train():
    """Entry point: parse CLI arguments, load model and data, run training.

    Reads a ``TrainingConfig`` plus HF ``TrainingArguments`` from argv,
    fine-tunes the causal LM on the task data, and saves the final model
    to ``args.output_dir``.
    """
    # Parse both dataclasses from the command line in one pass.
    parser = transformers.HfArgumentParser(
        (TrainingConfig, transformers.TrainingArguments))
    config, args = parser.parse_args_into_dataclasses()

    # Log the fully-merged configuration for reproducibility.
    merged = {**asdict(config), **asdict(args)}
    logging.info(f"Training config: {merged}")

    # Base model to continue pretraining from.
    model = transformers.AutoModelForCausalLM.from_pretrained(config.model_name)

    # Task-specific dataset(s)/collator, keyed off the same config fields.
    data_module = get_task_data_module(**asdict(config))

    # Train, persist the final checkpoint, then sync all distributed ranks
    # so no process exits before the save completes.
    trainer = transformers.Trainer(model=model, args=args, **data_module)
    trainer.train()
    trainer.save_model(output_dir=args.output_dir)
    trainer.accelerator.wait_for_everyone()
# Script entry point: run training when invoked directly
# (e.g. `python train.py ...` or via torchrun/accelerate launch).
if __name__ == "__main__":
    train()