Skip to content

Commit

Permalink
add text tagging by prompt mapper op
Browse files Browse the repository at this point in the history
  • Loading branch information
问昊 committed Aug 30, 2024
1 parent 22834ba commit 2920174
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 0 deletions.
161 changes: 161 additions & 0 deletions data_juicer/ops/mapper/text_tagging_by_prompt_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from typing import Dict

from loguru import logger

from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'text_tagging_by_prompt_mapper'

with AvailabilityChecking(['torch', 'transformers', 'vllm'], OP_NAME):
import torch
import transformers # noqa: F401
import vllm # noqa: F401

# avoid hanging when calling model in multiprocessing
torch.set_num_threads(1)


DEFAULT_CLASSIFICATION_PROMPT = """
请对下面的example文本回复的任务类别进行检测,并进行分类。
备选的分类包括:{tag_list}。
只回复对应的分类,不回复其他内容。
example文本:
{text}
""" # noqa

DEFAULT_CLASSIFICATION_LIST = [
'数学', '代码', '翻译', '角色扮演', '开放领域问答', '特定领域问答', '提取', '生成', '头脑风暴', '分类',
'总结', '改写', '其他'
] # noqa

DEFAULT_IDENTITY_BINARY_PROMPT = """
检测下面的example文本的回复中是否包含人工智能模型的自我认知(例如表现出自己是一个AI人工助手)。
备选的分类包括:{tag_list}。
只回复对应的分类,不回复其他内容。
example文本:
{text}
""" # noqa

DEFAULT_BINARY_LIST = ['是', '否']


# TODO: Extend LLM-based OPs into API-based implementation.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class TextTaggingByPromptMapper(Mapper):
"""
Mapper to generate text tags using prompt with LLM.
Recommended model list: [
'Qwen/Qwen2-7B-Instruct',
'meta-llama/Meta-Llama-3.1-8B-Instruct',
]
Other opensourced models with good instruction following ability
also works.
"""

_accelerator = 'cuda'

def __init__(self,
hf_model: str = 'Qwen/Qwen2-7B-Instruct',
trust_remote_code: bool = False,
prompt: str = DEFAULT_CLASSIFICATION_PROMPT,
tag_list: str = DEFAULT_CLASSIFICATION_LIST,
enable_vllm: bool = True,
tensor_parallel_size: int = None,
max_model_len: int = None,
max_num_seqs: int = 256,
sampling_params: Dict = {},
*args,
**kwargs):
"""
Initialization method.
:param hf_model: Hugginface model id.
:param trust_remote_code: passed to transformers
:param prompt: the prompt used to generate text tags.
:param tag_list: the list of tagging output options.
:param enable_vllm: Whether to use vllm for inference acceleration.
:param tensor_parallel_size: It is only valid when enable_vllm is True.
The number of GPUs to use for distributed execution with tensor
parallelism.
:param max_model_len: It is only valid when enable_vllm is True.
Model context length. If unspecified, will be automatically
derived from the model config.
:param max_num_seqs: It is only valid when enable_vllm is True.
Maximum number of sequences to be processed in a single iteration.
:param sampling_params: Sampling parameters for text generation.
e.g {'temperature': 0.9, 'top_p': 0.95}
:param args: extra args
:param kwargs: extra args
The default data format parsed by this interface is as follows:
Model Input:
请对下面的example文本回复的任务类别进行检测,并进行分类。备选的分类包括:["数学","代码","翻译","角色扮演","开放领域问答","特定领域问答", "提取", "生成", "头脑风暴", "分类","总结","改写", "其他"]。只回复对应的分类,不回复其他内容。
example文本:
{
"instruction": "找出方程 x2 - 3x = 0 的根。",
"input": "",
"output": "该方程可以写成 x(x-3)=0。\n\n根据乘法原理,x = 0或x - 3 = 0。\n\n因此,x1 = 0和x2 = 3是方程 x2 - 3x = 0 的两个根。"
}
Model Output:
数学
""" # noqa

super().__init__(*args, **kwargs)
self.num_proc = 1

self.prompt = prompt
self.tag_list = tag_list
self.enable_vllm = enable_vllm

if enable_vllm:
import torch
from vllm import SamplingParams

assert torch.cuda.device_count() >= 1, 'must be executed in CUDA'
if not tensor_parallel_size:
tensor_parallel_size = torch.cuda.device_count()
logger.info(f'Set tensor_parallel_size to \
{tensor_parallel_size} for vllm.')
self.model_key = prepare_model(
model_type='vllm',
pretrained_model_name_or_path=hf_model,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
max_model_len=max_model_len,
max_num_seqs=max_num_seqs)
self.sampling_params = SamplingParams(**sampling_params)
else:
self.model_key = prepare_model(
model_type='huggingface',
pretrained_model_name_or_path=hf_model,
trust_remote_code=trust_remote_code)
self.sampling_params = sampling_params

def process(self, sample, rank=None):
model, processor = get_model(self.model_key, rank, self.use_cuda())

if self.enable_vllm:
response = model.generate([
self.prompt.format(text=sample[self.text_key],
tag_list=self.tag_list)
], self.sampling_params)
output = response[0].outputs[0].text
else:
inputs = processor([
self.prompt.format(text=sample[self.text_key],
tag_list=self.tag_list)
],
return_tensors='pt').to(model.device)
response = model.generate(**inputs, **self.sampling_params)
output = processor.decode(response.cpu()[0],
skip_special_tokens=True)

text_tags = []
text_tags.append(output)
sample[Fields.text_tags] = text_tags

return sample
3 changes: 3 additions & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class Fields(object):
video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__'
video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__'

# text_tags
text_tags = DEFAULT_PREFIX + 'text_tags__'

# the name of the original file from which this sample was derived.
source_file = DEFAULT_PREFIX + 'source_file__'

Expand Down
61 changes: 61 additions & 0 deletions tests/ops/mapper/test_text_tagging_by_prompt_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import unittest
import json
from data_juicer.ops.mapper.text_tagging_by_prompt_mapper import TextTaggingByPromptMapper, DEFAULT_CLASSIFICATION_PROMPT, DEFAULT_CLASSIFICATION_LIST
from data_juicer.utils.constant import Fields
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
DataJuicerTestCaseBase)

def check_string_in_list(string_list, output):
if not string_list:
assert False, "输入的列表不能是空的"

for string in string_list:
if string in output:
return

assert False, f"没有字符串在输出中"

# Skip tests for this OP in the GitHub actions due to disk space limitation.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class TextTaggingByPromptTest(DataJuicerTestCaseBase):
text_key = 'text'

def _run_tagging(self, samples, enable_vllm=False, sampling_params={}, **kwargs):
op = TextTaggingByPromptMapper(
hf_model='Qwen/Qwen2-7B-Instruct',
prompt=DEFAULT_CLASSIFICATION_PROMPT,
enable_vllm=enable_vllm,
sampling_params=sampling_params,
**kwargs
)
for sample in samples:
result = op.process(sample)
out_tag = result[Fields.text_tags]
print(f'Output tag: {out_tag}')

# test one output qa sample
check_string_in_list(DEFAULT_CLASSIFICATION_LIST, out_tag)

def test_tagging(self):
samples = [
{
self.text_key: """{\n"instruction": "找出方程 x2 - 3x = 0 的根。",\n"input": "",\n"output": "该方程可以写成 x(x-3)=0。\n\n根据乘法原理,x = 0或x - 3 = 0。\n\n因此,x1 = 0和x2 = 3是方程 x2 - 3x = 0 的两个根。"\n}"""
}]
self._run_tagging(samples)

def test_tagging_vllm(self):
samples = [
{
self.text_key: """{\n"instruction": "找出方程 x2 - 3x = 0 的根。",\n"input": "",\n"output": "该方程可以写成 x(x-3)=0。\n\n根据乘法原理,x = 0或x - 3 = 0。\n\n因此,x1 = 0和x2 = 3是方程 x2 - 3x = 0 的两个根。"\n}"""
}]
self._run_tagging(
samples,
enable_vllm=True,
max_model_len=1024,
max_num_seqs=16,
sampling_params={'temperature': 0.1, 'top_p': 0.95, 'max_tokens': 256})


if __name__ == '__main__':
unittest.main()

0 comments on commit 2920174

Please sign in to comment.