test_falcon_gsm8k.py
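"""Evaluate Falcon-7B on GSM8K with a 3-shot prompt, batched through the
transformers text-generation pipeline; results are appended as multi-line
JSON records so an interrupted run can resume."""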
import os
import re
import json
import torch
import random
import transformers
from tqdm import tqdm
from datasets import Dataset
from transformers import AutoTokenizer, set_seed
from prompt_pattern import PROMPT, STOP_WORD
def clean(content):
    """Strip GSM8K calculator annotations (<<...>>) and flatten newlines."""
    # Non-greedy so multiple annotations on one line are removed individually.
    pattern = '<<.+?>>'
    result = re.findall(pattern, content)
    for t in result:
        content = content.replace(t, '')
    content = content.replace('\n', '. ')
    return content
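# Example: clean('He pays 5*3=<<5*3=15>>15 dollars.\nDone.')
#          -> 'He pays 5*3=15 dollars.. Done.'  (note the doubled period)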
def load_multi_line_json(f):
    """Parse a stream of pretty-printed JSON objects, one per record.

    Each object spans several lines and its closing brace starts a line,
    matching the json.dumps(..., indent=4) records written by make_query.
    """
    data = ''
    all_data = []
    for line in f.readlines():
        data = data + line
        if line.startswith('}'):
            all_data.append(json.loads(data))
            data = ''
    return all_data
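# The result file parsed here is a stream of indent=4 objects written back to
# back by make_query's fout.write below, e.g.:
# {
#     "question": "...",
#     "answer": "...",
#     "real_ans": "...",
#     "pred_ans": "..."
# }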
model = "/media/public/models/huggingface/falcon/falcon-7b"
set_seed(2023)
random.seed(2023)

# Vanilla few-shot prompt template and its stop word.
pattern = PROMPT['VANILLA']
stop = STOP_WORD['VANILLA']
# Build the 3-shot demonstration from randomly sampled training examples.
train_dataset = Dataset.from_json('/mnt/chenzhipeng/llm_data/pretrain_data/MathInstruction/gsm8k.json', field='train')
ids = random.sample(range(len(train_dataset)), 3)
demo = ''
for idx in ids:
    data = train_dataset[idx]
    problem = data['question']
    solution = data['answer']
    # GSM8K solutions end with '#### <final answer>'.
    answer = solution.split('####')[-1].strip()
    solution = clean(solution.split('####')[0].strip())
    demo = demo + pattern.format(problem, f'{solution} The answer is {answer}')
print(demo)
# Load the test split, then drop questions already present in the result
# file so an interrupted run resumes where it left off.
test_dataset = Dataset.from_json('/mnt/chenzhipeng/llm_data/pretrain_data/MathInstruction/gsm8k.json', field='test')
print(test_dataset)
eval_problem = []
test_data = []
if os.path.exists('result/falcon_7b-gsm8k_pipeline-3shot.json'):
    with open('result/falcon_7b-gsm8k_pipeline-3shot.json') as fin:
        eval_dataset = load_multi_line_json(fin)
    for data in eval_dataset:
        eval_problem.append(data['question'])
for data in test_dataset:
    if data['question'] not in eval_problem:
        test_data.append(data)
        print(data['question'])
test_dataset = test_data
tokenizer = AutoTokenizer.from_pretrained(model, padding_side='left')
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
    batch_size=6,
)
# Falcon's tokenizer defines no pad token; reuse EOS for left padding.
pipeline.tokenizer.pad_token_id = tokenizer.eos_token_id
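# For a batched (list) input, the pipeline yields one list per prompt, each
# holding num_return_sequences dicts like {'generated_text': ...}; hence the
# pred[0]['generated_text'] lookup in make_query below.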
fout = open('result/falcon_7b-gsm8k_pipeline-3shot.json', 'a')
inputs = []
origin_data = []


def make_query():
    """Greedily decode the buffered batch and append one record per example."""
    global inputs, origin_data, fout
    sequences = pipeline(
        inputs,
        max_length=1024,
        do_sample=False,
        top_k=1,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    for pred, data in zip(sequences, origin_data):
        tmp_data = data
        tmp_data['real_ans'] = data['answer'].split('####')[-1].strip()
        tmp_data['pred_ans'] = pred[0]['generated_text']
        fout.write(json.dumps(tmp_data, indent=4) + '\n')
    origin_data = []
    inputs = []
for data in tqdm(test_dataset):
    # A newline separates the question from the chain-of-thought trigger.
    inputs.append(demo + 'Problem: ' + data['question'] + "\nSolution: Let's think step by step")
    origin_data.append(data)
    # Flush whenever a full batch of 6 prompts has accumulated.
    if len(inputs) == 6:
        make_query()
# Flush the final partial batch.
if len(inputs) != 0:
    make_query()
fout.close()
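
# A minimal scoring sketch, not part of the original script: it assumes the
# result file is the multi-line JSON stream written by make_query and that
# each generation ends with the demo-induced phrase 'The answer is <number>'.
# The helper name score_results is hypothetical.
def score_results(path='result/falcon_7b-gsm8k_pipeline-3shot.json'):
    with open(path) as fin:
        records = load_multi_line_json(fin)
    correct = 0
    for rec in records:
        # Take the last 'The answer is <number>' occurrence in the output.
        hits = re.findall(r'The answer is\s*(-?[\d,.]+)', rec['pred_ans'])
        pred = hits[-1].rstrip('.').replace(',', '') if hits else None
        if pred == rec['real_ans'].replace(',', ''):
            correct += 1
    return correct / max(len(records), 1)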