-
Notifications
You must be signed in to change notification settings - Fork 187
/
generate_qa_from_examples_mapper.py
270 lines (236 loc) · 11 KB
/
generate_qa_from_examples_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
import json
import random
import re
from typing import Dict, Optional
from loguru import logger
from pydantic import PositiveInt
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
from ..base_op import OPERATORS, Mapper
torch = LazyLoader('torch', 'torch')
vllm = LazyLoader('vllm', 'vllm')
rouge = LazyLoader('rouge', 'rouge')
OP_NAME = 'generate_qa_from_examples_mapper'
# TODO: Extend LLM-based OPs into API-based implementation.
@OPERATORS.register_module(OP_NAME)
class GenerateQAFromExamplesMapper(Mapper):
"""
Mapper to generate question and answer pairs from examples.
You should configure an empty dataset in your yaml config file:
```
generated_dataset_config:
type: 'EmptyFormatter' # use `RayEmptyFormatter` when enable ray
length: ${The number of generated samples}
feature_keys: ${text key}
```
The number of samples generated is determined by
the length of the empty dataset.
"""
DEFAULT_SYSTEM_PROMPT = (
'请你仔细观察多个示例数据的输入和输出,按照你的理解,总结出相应规矩,然后写出一个新的【问题】和【回答】。'
'注意,新生成的【问题】和【回答】需要满足如下要求:\n'
'1. 生成的【问题】和【回答】不能与输入的【问题】和【回答】一致,但是需要保持格式相同。\n'
'2. 生成的【问题】不一定要局限于输入【问题】的话题或领域,生成的【回答】需要正确回答生成的【问题】。\n'
'3. 提供的【问题】和【回答】可能是多轮对话,生成的【问题】和【回答】也可以是多轮,但是需要保持格式相同。\n'
'4. 生成的【问题】和【回答】必须成对出现,而且【问题】需要在【回答】之前。\n')
DEFAULT_INPUT_TEMPLATE = '{}'
DEFAULT_EXAMPLE_TEMPLATE = '\n如下是一条示例数据:\n{}'
DEFAULT_QA_PAIR_TEMPLATE = '【问题】\n{}\n【回答】\n{}\n'
DEFAULT_OUTPUT_PATTERN = r'【问题】(.*?)【回答】(.*?)(?=【问题】|$)'
_accelerator = 'cuda'
def __init__(self,
hf_model: str = 'Qwen/Qwen2.5-7B-Instruct',
*,
seed_file: str = '',
example_num: PositiveInt = 3,
similarity_threshold: float = 0.7,
system_prompt: Optional[str] = None,
input_template: Optional[str] = None,
example_template: Optional[str] = None,
qa_pair_template: Optional[str] = None,
output_pattern: Optional[str] = None,
enable_vllm: bool = False,
model_params: Optional[Dict] = None,
sampling_params: Optional[Dict] = None,
**kwargs):
"""
Initialization method.
:param hf_model: Hugginface model ID.
:param seed_file: Path to the seed file in chatml format.
:param example_num: The number of selected examples.
Randomly select N examples from "seed_file" and
put them into prompt as QA examples.
:param similarity_threshold: The similarity score threshold
between the generated samples and the seed examples.
Range from 0 to 1. Samples with similarity score less than
this threshold will be kept.
:param system_prompt: System prompt for guiding the generation task.
:param input_template: Template for building the input prompt. It must
include one placeholder '{}', which will be replaced by
`example_num` formatted examples defined by `example_template`.
:param example_template: Template for formatting one QA example. It
must include one placeholder '{}', which will be replaced by one
formatted qa_pair.
:param qa_pair_template: Template for formatting a single QA pair
within each example. Must include two placeholders '{}' for the
question and answer.
:param output_pattern: Regular expression pattern to extract questions
and answers from model response.
:param enable_vllm: Whether to use vllm for inference acceleration.
:param model_params: Parameters for initializing the model.
:param sampling_params: Sampling parameters for text generation.
e.g {'temperature': 0.9, 'top_p': 0.95}
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)
if not seed_file:
raise ValueError(
'Please provide `seed_file` in chatml format.'
'Example: data-juicer/demos/data/demo-dataset-chatml.jsonl')
self.seed_file = seed_file
self.example_num = example_num
self.similarity_threshold = similarity_threshold
self.similarity_type = 'rouge_l'
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
self.example_template = example_template or self.DEFAULT_EXAMPLE_TEMPLATE # noqa: E501
self.qa_pair_template = qa_pair_template or \
self.DEFAULT_QA_PAIR_TEMPLATE
self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
self.enable_vllm = enable_vllm
model_params = model_params or {}
sampling_params = sampling_params or {}
if enable_vllm:
assert torch.cuda.device_count() >= 1, 'must be executed in CUDA'
# cannot initialize vllm replicas on different GPUs
self.num_proc = 1
if model_params.get('tensor_parallel_size') is None:
tensor_parallel_size = torch.cuda.device_count()
logger.info(f'Set tensor_parallel_size to \
{tensor_parallel_size} for vllm.')
model_params['tensor_parallel_size'] = tensor_parallel_size
self.model_key = prepare_model(
model_type='vllm',
pretrained_model_name_or_path=hf_model,
**model_params)
self.sampling_params = vllm.SamplingParams(**sampling_params)
else:
self.model_key = prepare_model(
model_type='huggingface',
pretrained_model_name_or_path=hf_model,
return_pipe=True,
**model_params)
self.sampling_params = sampling_params
self.seed_qa_samples = self._load_seed_qa_samples()
if len(self.seed_qa_samples) == 0:
raise ValueError('No QA data was parsed from the seed file!')
def _load_seed_qa_samples(self):
"""Load QA pairs from chatml format file."""
qa_samples = []
with open(self.seed_file, encoding='utf-8') as f:
lines = f.readlines()
for line in lines:
line = line.strip()
qa_pairs = self._parse_chatml_str(line)
if len(qa_pairs) > 0:
qa_samples.append(qa_pairs)
return qa_samples
def _sample_to_str(self, qa_sample):
return '\n'.join(['\n'.join(qa_pair) for qa_pair in qa_sample]) + '\n'
def _max_rouge_l_score(self, hypothesis, references):
r = rouge.Rouge()
max_score = 0.0
hyp_str = self._sample_to_str(hypothesis)
for reference in references:
ref_str = self._sample_to_str(reference)
scores = r.get_scores(hyp_str, ref_str)
rouge_l_score = scores[0]['rouge-l']['f']
if rouge_l_score > max_score:
max_score = rouge_l_score
return max_score
def _parse_chatml_str(self, sample_str):
user_input = None
assistant_output = None
qa_pairs = []
data = json.loads(sample_str)
for message in data['messages']:
role = message['role']
content = message['content']
if role == 'user':
user_input = content
elif role == 'assistant':
assistant_output = content
qa_pairs.append((user_input, assistant_output))
return qa_pairs
def build_input(self, qa_examples):
def format_qa_pairs(qa_example):
return ''.join([
self.qa_pair_template.format(q, a) for q, a in qa_example
if q and a
])
formatted_examples = ''.join([
self.example_template.format(qa_pairs=format_qa_pairs(qa_example))
for qa_example in qa_examples
])
input_prompt = self.input_template.format(examples=formatted_examples)
return input_prompt
def parse_output(self, raw_output):
logger.debug(raw_output)
output_qa_pairs = []
matches = re.findall(self.output_pattern, raw_output, re.DOTALL)
for match in matches:
question, answer = match
output_qa_pairs.append((question.strip(), answer.strip()))
return output_qa_pairs
def process_single(self, sample, rank=None):
model, _ = get_model(self.model_key, rank, self.use_cuda())
random_qa_samples = random.sample(self.seed_qa_samples,
self.example_num)
input_prompt = self.build_input(random_qa_samples)
messages = [{
'role': 'system',
'content': self.system_prompt
}, {
'role': 'user',
'content': input_prompt
}]
if self.enable_vllm:
response = model.chat(messages, self.sampling_params)
output = response[0].outputs[0].text
else:
# model is pipe
response = model(messages,
return_full_text=False,
**self.sampling_params)
output = response[0]['generated_text']
output_qa_pairs = self.parse_output(output)
if len(output_qa_pairs) == 0:
logger.warning('Parse model response error! '
'No data generated for the current response!')
sample.update({
self.query_key: '',
self.response_key: '',
self.history_key: self.empty_history()
})
return sample
if self.similarity_type == 'rouge_l':
sim_score = self._max_rouge_l_score(output_qa_pairs,
random_qa_samples)
else:
raise ValueError(
f'Not support similarity type "{self.similarity_type}"!')
if sim_score <= self.similarity_threshold:
query, response = output_qa_pairs[-1]
history = output_qa_pairs[:-1]
if len(history) == 0:
history = self.empty_history()
else:
query = response = ''
history = self.empty_history()
logger.info('Filter this generated sample due to similarity.')
sample.update({
self.query_key: query,
self.response_key: response,
self.history_key: history
})
return sample