train.py
import logging
from dataclasses import dataclass, field
from typing import Optional

import transformers
from peft import LoraConfig, get_peft_model
from transformers import (
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    PreTrainedTokenizer
)

from src.dataset.collator import InstructionTuningCollator
from src.dataset.dataset import InstructionTuningDataset
from src.model.init_llemr import init_llemr
from src.model.modeling_llemr import LlemrForConditionalGeneration
from src.model.utils import find_all_linear_names

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    name_or_path: Optional[str] = field(default=None)
    llm_pretrained_model_name_or_path: Optional[str] = field(default="Qwen/Qwen2-0.5B-Instruct")
    train_type: Optional[str] = field(
        default="train_both",
        metadata={
            "help": """
            1. train_multi_modal_projector
            2. train_both
            """
        },
    )
    use_lora: Optional[bool] = field(default=True)
    # LoRA hyperparameters, used only when use_lora is True.
    lora_r: int = 32
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_bias: str = "none"
    vision_hidden_size: int = 768


@dataclass
class DataArguments:
    source: Optional[str] = field(default="note")


def load_model(model_args: ModelArguments):
    if model_args.name_or_path is not None:
        logger.warning(f"Load model {model_args.name_or_path} from pretrained")
        model = LlemrForConditionalGeneration.from_pretrained(
            model_args.name_or_path
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.name_or_path,
            padding_side="left"
        )
    else:
        logger.warning(f"Init model {model_args.llm_pretrained_model_name_or_path}")
        model, tokenizer = init_llemr(
            model_args.llm_pretrained_model_name_or_path, model_args.vision_hidden_size
        )

    assert model_args.train_type in ["train_multi_modal_projector", "train_both"]
    if model_args.train_type == "train_multi_modal_projector":
        logger.warning("Train multi_modal_projector")
        # Freeze the language model so only the projector is updated.
        for param in model.language_model.parameters():
            param.requires_grad = False
    else:
        logger.warning("Train both")

    if model_args.use_lora:
        assert model_args.train_type == "train_both"
        logger.warning("Use LoRA")
        config = LoraConfig(
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=model_args.lora_dropout,
            bias=model_args.lora_bias,
            task_type="CAUSAL_LM",
            # Keep the projector fully trainable and saved alongside the LoRA adapters.
            modules_to_save=["multi_modal_projector"],
        )
        model = get_peft_model(model, config)
    else:
        logger.warning("Not using LoRA")

    return model, tokenizer
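
# A quick sanity check (a sketch, not part of the original script): when
# use_lora is True, load_model returns a PEFT-wrapped model, and PEFT's
# print_trainable_parameters() reports how many parameters will actually
# train under the LoRA config above (the adapters plus multi_modal_projector).
#
#     model, tokenizer = load_model(ModelArguments())
#     model.print_trainable_parameters()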


def load_data(data_args: DataArguments, tokenizer: PreTrainedTokenizer):
    train_dataset = InstructionTuningDataset(
        split="train",
        source=data_args.source,
    )
    val_dataset = InstructionTuningDataset(
        split="val",
        source=data_args.source,
    )
    collator = InstructionTuningCollator(
        tokenizer=tokenizer,
    )
    return train_dataset, val_dataset, collator


def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model, tokenizer = load_model(model_args)
    train_dataset, val_dataset, collator = load_data(data_args, tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collator,
    )
    # Save the tokenizer next to the checkpoints before training starts.
    tokenizer.save_pretrained(training_args.output_dir)
    trainer.train()


if __name__ == "__main__":
    train()
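
# Example launch (a sketch, not from the repository): ModelArguments and
# DataArguments fields map to CLI flags via HfArgumentParser; --output_dir,
# --per_device_train_batch_size, and --num_train_epochs are standard
# transformers TrainingArguments options, and the values below are illustrative.
#
#     python train.py \
#         --llm_pretrained_model_name_or_path Qwen/Qwen2-0.5B-Instruct \
#         --train_type train_both \
#         --source note \
#         --output_dir ./checkpoints \
#         --per_device_train_batch_size 4 \
#         --num_train_epochs 3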