[WIP] Add text tagging by prompt mapper op #408
base: main
New file: data_juicer/ops/mapper/text_tagging_by_prompt_mapper.py
@@ -0,0 +1,161 @@
from typing import Dict, List

from loguru import logger

from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'text_tagging_by_prompt_mapper'

with AvailabilityChecking(['torch', 'transformers', 'vllm'], OP_NAME):
    import torch
    import transformers  # noqa: F401
    import vllm  # noqa: F401

    # avoid hanging when calling the model in multiprocessing
    torch.set_num_threads(1)


# Default classification prompt (Chinese). It asks the model to detect the
# task category of the response in the example text and to reply with the
# matching category from {tag_list} only, with no other content.
DEFAULT_CLASSIFICATION_PROMPT = """
请对下面的example文本回复的任务类别进行检测,并进行分类。
备选的分类包括:{tag_list}。
只回复对应的分类,不回复其他内容。
example文本:
{text}
"""  # noqa

# Default candidate tags (Chinese): math, code, translation, role play,
# open-domain QA, domain-specific QA, extraction, generation, brainstorming,
# classification, summarization, rewriting, other.
DEFAULT_CLASSIFICATION_LIST = [
    '数学', '代码', '翻译', '角色扮演', '开放领域问答', '特定领域问答', '提取', '生成', '头脑风暴', '分类',
    '总结', '改写', '其他'
]  # noqa

# Default binary prompt (Chinese). It asks the model whether the response in
# the example text contains AI self-cognition (e.g. presenting itself as an
# AI assistant) and to reply with one of {tag_list} only.
DEFAULT_IDENTITY_BINARY_PROMPT = """
检测下面的example文本的回复中是否包含人工智能模型的自我认知(例如表现出自己是一个AI人工助手)。
备选的分类包括:{tag_list}。
只回复对应的分类,不回复其他内容。

example文本:
{text}
"""  # noqa

# Binary options (Chinese): yes / no.
DEFAULT_BINARY_LIST = ['是', '否']
# TODO: Extend LLM-based OPs into API-based implementation.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class TextTaggingByPromptMapper(Mapper):
    """
    Mapper to generate text tags by prompting an LLM.
    Recommended models: [
        'Qwen/Qwen2-7B-Instruct',
        'meta-llama/Meta-Llama-3.1-8B-Instruct',
    ]
    Other open-source models with good instruction-following ability
    also work.
    """

    _accelerator = 'cuda'

    def __init__(self,
                 hf_model: str = 'Qwen/Qwen2-7B-Instruct',
                 trust_remote_code: bool = False,
                 prompt: str = DEFAULT_CLASSIFICATION_PROMPT,
                 tag_list: List[str] = DEFAULT_CLASSIFICATION_LIST,
                 enable_vllm: bool = True,
                 tensor_parallel_size: int = None,
                 max_model_len: int = None,
                 max_num_seqs: int = 256,
                 sampling_params: Dict = {},
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_model: Hugging Face model id.
        :param trust_remote_code: passed through to transformers.
        :param prompt: the prompt used to generate text tags.
        :param tag_list: the list of candidate output tags.
        :param enable_vllm: whether to use vllm for inference acceleration.
        :param tensor_parallel_size: only valid when enable_vllm is True.
            The number of GPUs to use for distributed execution with tensor
            parallelism.
        :param max_model_len: only valid when enable_vllm is True.
            Model context length. If unspecified, it will be automatically
            derived from the model config.
        :param max_num_seqs: only valid when enable_vllm is True.
            Maximum number of sequences to be processed in a single
            iteration.
        :param sampling_params: sampling parameters for text generation,
            e.g. {'temperature': 0.9, 'top_p': 0.95}.
        :param args: extra args.
        :param kwargs: extra args.

        The default data format parsed by this interface is as follows:

        Model Input:
            请对下面的example文本回复的任务类别进行检测,并进行分类。备选的分类包括:["数学", "代码", "翻译", "角色扮演", "开放领域问答", "特定领域问答", "提取", "生成", "头脑风暴", "分类", "总结", "改写", "其他"]。只回复对应的分类,不回复其他内容。
            example文本:
            {
                "instruction": "找出方程 x2 - 3x = 0 的根。",
                "input": "",
                "output": "该方程可以写成 x(x-3)=0。\n\n根据乘法原理,x = 0或x - 3 = 0。\n\n因此,x1 = 0和x2 = 3是方程 x2 - 3x = 0 的两个根。"
            }
        Model Output:
            数学
        """  # noqa

        super().__init__(*args, **kwargs)
        self.num_proc = 1

        self.prompt = prompt
        self.tag_list = tag_list
        self.enable_vllm = enable_vllm

        if enable_vllm:
            import torch
            from vllm import SamplingParams

            assert torch.cuda.device_count() >= 1, 'must be executed in CUDA'
            if not tensor_parallel_size:
                tensor_parallel_size = torch.cuda.device_count()
                logger.info(f'Set tensor_parallel_size to '
                            f'{tensor_parallel_size} for vllm.')
            self.model_key = prepare_model(
                model_type='vllm',
                pretrained_model_name_or_path=hf_model,
                trust_remote_code=trust_remote_code,
                tensor_parallel_size=tensor_parallel_size,
                max_model_len=max_model_len,
                max_num_seqs=max_num_seqs)
            self.sampling_params = SamplingParams(**sampling_params)
        else:
            self.model_key = prepare_model(
                model_type='huggingface',
                pretrained_model_name_or_path=hf_model,
                trust_remote_code=trust_remote_code)
            self.sampling_params = sampling_params

    def process(self, sample, rank=None):
        model, processor = get_model(self.model_key, rank, self.use_cuda())

        if self.enable_vllm:
            response = model.generate([
                self.prompt.format(text=sample[self.text_key],
                                   tag_list=self.tag_list)
            ], self.sampling_params)
            output = response[0].outputs[0].text
        else:
            inputs = processor([
                self.prompt.format(text=sample[self.text_key],
                                   tag_list=self.tag_list)
            ],
                               return_tensors='pt').to(model.device)
            response = model.generate(**inputs, **self.sampling_params)
            output = processor.decode(response.cpu()[0],
                                      skip_special_tokens=True)

        text_tags = []
        text_tags.append(output)
Review comment: It is better to …
        sample[Fields.text_tags] = text_tags
Review comment: Please refer to #423 for adding a user-specified tag field name.

        return sample
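For orientation, the OP above can be driven directly from Python. The following is a minimal sketch, not part of this PR: it assumes the module path implied by the test imports below, the Mapper's default text_key of 'text', an environment that can actually load the model, and illustrative sampling parameters.

```python
from data_juicer.ops.mapper.text_tagging_by_prompt_mapper import (
    DEFAULT_CLASSIFICATION_LIST, TextTaggingByPromptMapper)
from data_juicer.utils.constant import Fields

# Instantiate the mapper with the default Chinese classification prompt.
# enable_vllm=False exercises the plain transformers code path.
op = TextTaggingByPromptMapper(
    hf_model='Qwen/Qwen2-7B-Instruct',
    tag_list=DEFAULT_CLASSIFICATION_LIST,
    enable_vllm=False,
    sampling_params={'max_new_tokens': 16})  # illustrative value

# A sample in the default data format described in the docstring above
# (the "output" field is elided here for brevity).
sample = {'text': '{"instruction": "找出方程 x2 - 3x = 0 的根。", '
                  '"input": "", "output": "..."}'}

tagged = op.process(sample)
print(tagged[Fields.text_tags])  # expected to hold one tag, e.g. ['数学']
```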
New test file (path not shown in the diff)
@@ -0,0 +1,61 @@
import unittest

from data_juicer.ops.mapper.text_tagging_by_prompt_mapper import (
    DEFAULT_CLASSIFICATION_LIST, DEFAULT_CLASSIFICATION_PROMPT,
    TextTaggingByPromptMapper)
from data_juicer.utils.constant import Fields
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
                                              DataJuicerTestCaseBase)


def check_string_in_list(string_list, output):
    if not string_list:
        assert False, 'the candidate tag list must not be empty'

    for string in string_list:
Review comment: Why not directly check …
        if string in output:
            return

    assert False, 'none of the candidate tags appears in the output'
# Skip tests for this OP in the GitHub actions due to disk space limitation.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class TextTaggingByPromptTest(DataJuicerTestCaseBase):
    text_key = 'text'

    def _run_tagging(self,
                     samples,
                     enable_vllm=False,
                     sampling_params={},
                     **kwargs):
        op = TextTaggingByPromptMapper(
            hf_model='Qwen/Qwen2-7B-Instruct',
            prompt=DEFAULT_CLASSIFICATION_PROMPT,
            enable_vllm=enable_vllm,
            sampling_params=sampling_params,
            **kwargs)
        for sample in samples:
            result = op.process(sample)
            out_tag = result[Fields.text_tags]
            print(f'Output tag: {out_tag}')

            # check that one of the candidate tags appears in the output
            check_string_in_list(DEFAULT_CLASSIFICATION_LIST, out_tag)
    def test_tagging(self):
        samples = [{
            self.text_key:
            """{\n"instruction": "找出方程 x2 - 3x = 0 的根。",\n"input": "",\n"output": "该方程可以写成 x(x-3)=0。\n\n根据乘法原理,x = 0或x - 3 = 0。\n\n因此,x1 = 0和x2 = 3是方程 x2 - 3x = 0 的两个根。"\n}"""
        }]
        self._run_tagging(samples)

    def test_tagging_vllm(self):
        samples = [{
            self.text_key:
            """{\n"instruction": "找出方程 x2 - 3x = 0 的根。",\n"input": "",\n"output": "该方程可以写成 x(x-3)=0。\n\n根据乘法原理,x = 0或x - 3 = 0。\n\n因此,x1 = 0和x2 = 3是方程 x2 - 3x = 0 的两个根。"\n}"""
        }]
        self._run_tagging(samples,
                          enable_vllm=True,
                          max_model_len=1024,
                          max_num_seqs=16,
                          sampling_params={
                              'temperature': 0.1,
                              'top_p': 0.95,
                              'max_tokens': 256
                          })
Review comment: Can this OP run in multiple processes, especially without vllm? Please add more tests.

Review comment: It would be better to add tests for …
if __name__ == '__main__':
    unittest.main()
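A reviewer asks above for more tests; one concrete gap is that DEFAULT_IDENTITY_BINARY_PROMPT and DEFAULT_BINARY_LIST are defined in the OP but never exercised here. Below is a minimal sketch of such a test, meant to sit inside TextTaggingByPromptTest (so self.text_key, check_string_in_list, and Fields are already in scope); the sample text and sampling parameters are illustrative.

```python
# Hypothetical extra imports for the top of this test file:
from data_juicer.ops.mapper.text_tagging_by_prompt_mapper import (
    DEFAULT_BINARY_LIST, DEFAULT_IDENTITY_BINARY_PROMPT)


# Inside TextTaggingByPromptTest:
def test_identity_binary_tagging(self):
    # The binary prompt asks whether the response shows AI self-cognition;
    # the expected tags are '是' (yes) or '否' (no).
    op = TextTaggingByPromptMapper(
        hf_model='Qwen/Qwen2-7B-Instruct',
        prompt=DEFAULT_IDENTITY_BINARY_PROMPT,
        tag_list=DEFAULT_BINARY_LIST,
        enable_vllm=False,
        sampling_params={'max_new_tokens': 8})  # illustrative value
    sample = {
        self.text_key:
        '{"instruction": "你是谁?", "output": "我是一个AI助手,很高兴为你服务。"}'
    }
    result = op.process(sample)
    # One of the binary options should be produced as the tag.
    check_string_in_list(DEFAULT_BINARY_LIST, result[Fields.text_tags])
```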
Review comment: If enable_vllm is False, num_proc=1 will force this OP to run in a single process/GPU. Is this the desired behavior?