# mlora.py: m-LoRA main program (forked from TUDB-Labs/mLoRA; repository archived and read-only since Aug 16, 2024)

import argparse
import json
import logging
import os
import sys
from typing import Any, Dict, List, Tuple, Union

import torch
from transformers.utils import is_flash_attn_2_available

import mlora

# Command Line Arguments
parser = argparse.ArgumentParser(description="m-LoRA main program")
parser.add_argument(
    "--base_model", type=str, required=True, help="Path to or name of the base model"
)
parser.add_argument(
    "--inference", action="store_true", help="Run in inference mode (for testing only)"
)
parser.add_argument(
    "--evaluate", action="store_true", help="Run in evaluation mode (for testing only)"
)
parser.add_argument(
    "--disable_prompter", action="store_true", help="Disable the prompter during inference"
)
parser.add_argument(
    "--load_adapter",
    action="store_true",
    help="Load the adapter from file instead of initializing it randomly",
)
parser.add_argument(
    "--disable_adapter", action="store_true", help="Disable the adapter modules"
)
parser.add_argument(
    "--attn_impl", type=str, help="Specify the attention implementation"
)
parser.add_argument(
    "--sliding_window",
    action="store_true",
    help="Use sliding window attention (requires flash attention)",
)
parser.add_argument(
    "--disable_cache",
    action="store_true",
    help="Disable the cache during inference",
)
parser.add_argument(
    "--cache_implementation",
    type=str,
    help="Specify the cache implementation",
)
parser.add_argument(
    "--fp16", action="store_true", help="Load the base model in float16 precision"
)
parser.add_argument(
    "--bf16", action="store_true", help="Load the base model in bfloat16 precision"
)
parser.add_argument(
    "--tf32", action="store_true", help="Use tfloat32 instead of float32 when available"
)
parser.add_argument(
    "--load_8bit", action="store_true", help="Load the base model with 8-bit quantization"
)
parser.add_argument(
    "--load_4bit", action="store_true", help="Load the base model with 4-bit quantization"
)
parser.add_argument("--device", type=str, help="Specify which device (GPU) to use")
parser.add_argument(
    "--config", type=str, required=True, help="Path to the fine-tuning configuration"
)
parser.add_argument(
    "--seed", type=int, default=42, help="Random seed (integer), default 42"
)
parser.add_argument(
    "--dir", type=str, default=".", help="Path for reading and saving checkpoints"
)
parser.add_argument("--disable_log", action="store_true", help="Disable logging")
parser.add_argument("--log_file", type=str, help="Save the log to a specific file")
parser.add_argument(
    "--verbose", action="store_true", help="Show extra information such as parameters"
)
parser.add_argument(
    "--overwrite",
    action="store_true",
    help="Overwrite the adapter model if an older one exists",
)
parser.add_argument("--debug", action="store_true", help="Enable debugging mode")
parser.add_argument(
    "--deterministic",
    action="store_true",
    help="Use deterministic algorithms to improve reproducibility",
)
args = parser.parse_args()
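
# Illustrative invocations (paths and config names are placeholders, not taken
# from this repository):
#   python mlora.py --base_model <model_name_or_path> --config <config.json> --bf16
#   python mlora.py --base_model <model_name_or_path> --config <config.json> --inference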


def query_yes_no(question, default="no"):
    valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
    if default is None:
        prompt = " [y/n] "
    elif default == "yes":
        prompt = " [Y/n] "
    elif default == "no":
        prompt = " [y/N] "
    else:
        raise ValueError("invalid default answer: '%s'" % default)

    while True:
        sys.stdout.write(question + prompt)
        choice = input().lower()
        if default is not None and choice == "":
            return valid[default]
        elif choice in valid:
            return valid[choice]
        else:
            sys.stdout.write("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")


def load_base_model() -> Tuple[mlora.Tokenizer, mlora.LLMModel]:
    logging.info("Initializing pre-trained model.")
    model = mlora.LLMModel.from_pretrained(
        name_or_path=args.base_model,
        device=args.device,
        attn_impl=args.attn_impl,
        use_sliding_window=args.sliding_window,
        bits=(8 if args.load_8bit else (4 if args.load_4bit else None)),
        load_dtype=(
            torch.bfloat16
            if args.bf16
            else (torch.float16 if args.fp16 else torch.float32)
        ),
    )
    tokenizer = mlora.Tokenizer(args.base_model)
    return tokenizer, model
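
# For example, following the flag handling above: "--load_8bit --bf16" yields
# from_pretrained(bits=8, load_dtype=torch.bfloat16), while passing no
# precision flags loads the model in float32 with bits=None.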


def init_adapter_config(
    config: Dict[str, Any],
    llm_model: mlora.LLMModel,
) -> List[Union[mlora.GenerateConfig, mlora.TrainConfig]]:
    config_list = []

    if config["cutoff_len"] == -1:
        config["cutoff_len"] = llm_model.max_seq_len_
        logging.info(f"Setting cutoff_len to {llm_model.max_seq_len_} automatically.")

    for lora_config in config["lora"]:
        adapter_name = lora_config["name"]
        adapter_path = f"{args.dir}{os.sep}{adapter_name}"
        if not args.load_adapter and os.path.exists(adapter_path):
            if args.overwrite:
                logging.warning(
                    f"Overwriting existing adapter model file: {adapter_path}"
                )
            elif not query_yes_no(
                f"Existing adapter model file detected: {adapter_path}\nOverwrite?"
            ):
                logging.info("User canceled training due to file conflict.")
                exit(0)

        if args.load_adapter:
            llm_model.load_adapter(adapter_path, adapter_name)
        else:
            llm_model.init_adapter(mlora.lora_config_factory(lora_config))

        if args.inference:
            config_class = mlora.GenerateConfig(adapter_name=adapter_name)
            if not args.disable_prompter:
                config_class.prompt_template = lora_config.get("prompt", None)
            config_list.append(config_class)
        elif args.evaluate:
            config_list.extend(mlora.EvaluateConfig.from_config(lora_config))
        else:
            config_list.append(mlora.TrainConfig.from_config(lora_config))

        if args.verbose:
            logging.info(config_list[-1].__dict__)

    return config_list
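
# Illustrative shape of the JSON config consumed by init_adapter_config and the
# main block below (only keys read in this file are shown; per-adapter fields
# expected by mlora.lora_config_factory and the *Config.from_config helpers are
# elided):
# {
#   "cutoff_len": -1,
#   "save_step": ...,
#   "train_strategy": ...,
#   "lora": [
#     {"name": "adapter_0", "prompt": ..., ...}
#   ]
# }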


def inference_callback(cur_pos, outputs):
    print(f"POSITION: {cur_pos}")
    for adapter_name, output in outputs.items():
        print(f"{adapter_name} OUTPUT: {output[0]}")


def inference(
    model: mlora.LLMModel,
    tokenizer: mlora.Tokenizer,
    configs: List[mlora.GenerateConfig],
    concurrent_jobs: int,
):
    while True:
        input_raw = input("INPUT WITHOUT PROMPT: ")
        if input_raw == "QUIT":
            return
        for config in configs:
            config.prompts = [input_raw]
        callback = None if args.disable_log else inference_callback
        outputs = mlora.generate(
            model,
            tokenizer,
            configs,
            max_gen_len=128,
            use_cache=not args.disable_cache,
            concurrent_jobs=concurrent_jobs,
            cache_implementation=args.cache_implementation,
            stream_callback=callback,
        )
        print(f"\n{'='*10}\n")
        print(f"PROMPT: {input_raw}")
        for adapter_name, output in outputs.items():
            print(f"{adapter_name} OUTPUT:")
            print(output[0])
        print(f"\n{'='*10}\n")


# Main Function
if __name__ == "__main__":
    if args.debug:
        torch.autograd.set_detect_anomaly(True)

    if args.inference or args.evaluate:
        args.load_adapter = True
        inference_mode = True
    else:
        inference_mode = False

    mlora.setup_logging("INFO", args.log_file)

    mlora_backend = mlora.backend
    if not mlora_backend.check_available():
        exit(-1)

    if args.attn_impl is None:
        if (
            inference_mode
            and mlora_backend.device_name() == "cuda"
            and is_flash_attn_2_available()
        ):
            args.attn_impl = "flash_attn"
        else:
            args.attn_impl = "eager"

    if args.device is None:
        args.device = mlora.backend.default_device_name()

    mlora_backend.use_deterministic_algorithms(args.deterministic)
    mlora_backend.allow_tf32(args.tf32)
    mlora_backend.manual_seed(args.seed)

    with open(args.config, "r", encoding="utf8") as fp:
        config = json.load(fp)

    tokenizer, model = load_base_model()
    adapters = init_adapter_config(config, model)
    mlora_backend.empty_cache()

    if os.getenv("MLORA_EVALUATE_MODE") is None:
        logging.info("Using efficient operators.")
    else:
        logging.info("Using deterministic operators.")

    if args.inference:
        inference(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            concurrent_jobs=config.get("inference_lora_simultaneously_num", 2),
        )
    elif args.evaluate:
        mlora.evaluate(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            max_concurrent_jobs=config.get("eval_lora_simultaneously_num", None),
            retrying_steps=config.get("eval_rollback_retrying_steps", 20),
            max_seq_len=config["cutoff_len"],
            save_file=config.get("evaluate_result", None),
        )
    else:
        mlora.train(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            max_concurrent_jobs=config.get("train_lora_simultaneously_num", None),
            strategy=config["train_strategy"],
            cutoff_len=config["cutoff_len"],
            save_step=config["save_step"],
            save_dir=args.dir,
        )
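
# Mode dispatch, as implemented above: --inference starts the interactive
# generation loop, --evaluate runs mlora.evaluate over the configured adapters,
# and the default path runs mlora.train with the strategy from the config file.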