forked from SCIR-HI/Huatuo-Llama-Med-Chinese
-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer_literature.py
127 lines (112 loc) · 4.13 KB
/
infer_literature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import sys
import json
import fire
import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer
from utils.prompter import Prompter
# Device that tokenized prompts are moved to before generation. Previously
# this name was only bound when CUDA was available, so CPU-only machines
# crashed with NameError at the first `.to(device)`; fall back to CPU instead.
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_instruction(instruct_dir):
    """Load a JSON-Lines file into a list of parsed records.

    Args:
        instruct_dir: Path to a file containing one JSON document per line
            (despite the name, this is a file path, not a directory).

    Returns:
        A list with one parsed object per non-blank line, in file order.
        An empty file yields an empty list.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    input_data = []
    # Explicit UTF-8: the data may contain non-ASCII (e.g. Chinese) text, and
    # the platform default encoding (e.g. on Windows) would mis-decode it.
    with open(instruct_dir, "r", encoding="utf-8") as f:
        # Stream line by line instead of materialising the file with readlines().
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # crashing in json.loads("").
                continue
            input_data.append(json.loads(line))
    return input_data
def main(
    load_8bit: bool = False,
    base_model: str = "",
    # Which demo to run: "single" answers a fixed list of built-in questions,
    # "multi" runs an interactive multi-turn chat on stdin.
    single_or_multi: str = "",
    use_lora: bool = True,
    lora_weights: str = "tloen/alpaca-lora-7b",
    # The prompt template to use, will default to med_template.
    prompt_template: str = "med_template",
    # Number of user turns in "multi" mode (previously hard-coded to 5).
    max_turns: int = 5,
):
    """Run medical QA inference with a causal LM, optionally LoRA-adapted.

    Args:
        load_8bit: Load the base model with ``load_in_8bit``. When False the
            model is cast to fp16 via ``model.half()``.
        base_model: Hugging Face model id or local path of the base model.
            Required.
        single_or_multi: ``"single"`` prints answers to a fixed list of
            built-in questions; ``"multi"`` reads ``max_turns`` user turns
            from stdin and keeps the dialogue history in the prompt.
        use_lora: Whether to wrap the base model with the PEFT adapter found
            at ``lora_weights``.
        lora_weights: Hugging Face id or local path of the LoRA adapter.
        prompt_template: Template name passed to ``Prompter``.
        max_turns: Number of interactive turns in ``"multi"`` mode.

    Raises:
        ValueError: If ``base_model`` is empty or ``single_or_multi`` is not
            ``"single"`` or ``"multi"``. Both are checked before any model is
            loaded.
    """
    # Validate arguments up front so a missing or mistyped flag does not cost
    # a multi-GB model load that then fails obscurely or silently does nothing.
    if not base_model:
        raise ValueError(
            "Please specify a --base_model (Hugging Face model id or local path)"
        )
    if single_or_multi not in ("single", "multi"):
        raise ValueError(
            "--single_or_multi must be 'single' or 'multi', "
            f"got {single_or_multi!r}"
        )

    # Device for the tokenized prompt. Computed locally rather than relying on
    # a module-level name that may be unbound on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    prompter = Prompter(prompt_template)
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=load_8bit,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    if use_lora:
        print(f"using lora {lora_weights}")
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
        )

    # unwind broken decapoda-research config
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    if not load_8bit:
        model.half()  # seems to fix bugs for some users.

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    def evaluate(
        instruction,
        input=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
        max_new_tokens=256,
        **kwargs,
    ):
        """Generate a response for one instruction and return the answer text.

        The prompt is built by ``prompter.generate_prompt`` and the decoded
        output is passed through ``prompter.get_response`` (presumably this
        strips the prompt and keeps only the answer; confirm in utils.prompter).
        Extra ``kwargs`` are forwarded to ``GenerationConfig``.

        NOTE(review): ``do_sample`` is not set, so ``temperature``/``top_p``/
        ``top_k`` are likely ignored by beam search; confirm intended.
        """
        prompt = prompter.generate_prompt(instruction, input)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        # Pass the tokenizer's own mask. pad_token_id was forced to 0 (unk)
        # above, so without an explicit mask generate() may infer one that
        # hides genuine id-0 tokens in the prompt.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        s = generation_output.sequences[0]
        output = tokenizer.decode(s)
        return prompter.get_response(output)

    if single_or_multi == "multi":
        # Running dialogue history: each turn appends "<user>: ..." and then
        # " <bot>: ...", so every call sees the whole conversation so far.
        instruction = ""
        for _ in range(max_turns):
            inp = "<user>: " + input("请输入:")
            instruction = instruction + inp
            response = evaluate(instruction)
            # Keep the model's reply on one line inside the history.
            response = response.replace("\n", "")
            print("Response:", response)
            instruction = instruction + " <bot>: " + response
    else:  # "single": validated above
        questions = [
            "肝癌是什么?有哪些症状和迹象?",
            "肝癌是如何诊断的?有哪些检查和测试可以帮助诊断?",
            "Sorafenib是一种口服的多靶点酪氨酸激酶抑制剂,它的作用机制是什么?",
            "Regorafenib是一种口服的多靶点酪氨酸激酶抑制剂,它的作用机制是什么?它和Sorafenib有什么不同?",
            "肝癌药物治疗的副作用有哪些?如何缓解这些副作用?",
            "肝癌药物治疗的费用高昂,如何降低治疗的经济负担?",
            "我想了解一下β-谷甾醇是否可作为肝癌的治疗药物",
            "能介绍一下最近Hsa_circ_0008583在肝细胞癌治疗中的潜在应用的研究么?",
        ]
        for instruction in questions:
            print("instruction:", instruction)
            instruction = "<user>: " + instruction
            print("Response:", evaluate(instruction))
if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags
    # (e.g. --base_model, --single_or_multi) via python-fire.
    fire.Fire(component=main)