[WIP] Add text tagging by prompt mapper op #408

Open · wants to merge 1 commit into base: main
161 changes: 161 additions & 0 deletions data_juicer/ops/mapper/text_tagging_by_prompt_mapper.py
@@ -0,0 +1,161 @@
from typing import Dict, List

from loguru import logger

from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'text_tagging_by_prompt_mapper'

with AvailabilityChecking(['torch', 'transformers', 'vllm'], OP_NAME):
    import torch
    import transformers  # noqa: F401
    import vllm  # noqa: F401

    # avoid hanging when calling model in multiprocessing
    torch.set_num_threads(1)


# Default prompt (in Chinese): classify the task type of the example text's
# response; the candidate categories are {tag_list}; reply with the category
# only and nothing else.
DEFAULT_CLASSIFICATION_PROMPT = """
请对下面的example文本回复的任务类别进行检测,并进行分类。
备选的分类包括:{tag_list}。
只回复对应的分类,不回复其他内容。
example文本:
{text}
"""  # noqa

# Default categories (in Chinese): math, code, translation, role play,
# open-domain QA, domain-specific QA, extraction, generation, brainstorming,
# classification, summarization, rewriting, other.
DEFAULT_CLASSIFICATION_LIST = [
    '数学', '代码', '翻译', '角色扮演', '开放领域问答', '特定领域问答', '提取', '生成', '头脑风暴', '分类',
    '总结', '改写', '其他'
]  # noqa

# Default prompt (in Chinese): detect whether the example text's response
# contains AI self-identification (e.g. presenting itself as an AI assistant);
# reply with one of the options in {tag_list} and nothing else.
DEFAULT_IDENTITY_BINARY_PROMPT = """
检测下面的example文本的回复中是否包含人工智能模型的自我认知(例如表现出自己是一个AI人工助手)。
备选的分类包括:{tag_list}。
只回复对应的分类,不回复其他内容。

example文本:
{text}
"""  # noqa

# Binary options (in Chinese): yes / no.
DEFAULT_BINARY_LIST = ['是', '否']


# TODO: Extend LLM-based OPs into API-based implementation.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class TextTaggingByPromptMapper(Mapper):
    """
    Mapper to generate text tags for samples by prompting an LLM.
    Recommended models: [
        'Qwen/Qwen2-7B-Instruct',
        'meta-llama/Meta-Llama-3.1-8B-Instruct',
    ]
    Other open-source models with good instruction-following ability
    also work.
    """

    _accelerator = 'cuda'

    def __init__(self,
                 hf_model: str = 'Qwen/Qwen2-7B-Instruct',
                 trust_remote_code: bool = False,
                 prompt: str = DEFAULT_CLASSIFICATION_PROMPT,
                 tag_list: List[str] = DEFAULT_CLASSIFICATION_LIST,
                 enable_vllm: bool = True,
                 tensor_parallel_size: int = None,
                 max_model_len: int = None,
                 max_num_seqs: int = 256,
                 sampling_params: Dict = {},
                 *args,
                 **kwargs):
"""
Initialization method.
:param hf_model: Hugginface model id.
:param trust_remote_code: passed to transformers
:param prompt: the prompt used to generate text tags.
:param tag_list: the list of tagging output options.
:param enable_vllm: Whether to use vllm for inference acceleration.
:param tensor_parallel_size: It is only valid when enable_vllm is True.
The number of GPUs to use for distributed execution with tensor
parallelism.
:param max_model_len: It is only valid when enable_vllm is True.
Model context length. If unspecified, will be automatically
derived from the model config.
:param max_num_seqs: It is only valid when enable_vllm is True.
Maximum number of sequences to be processed in a single iteration.
:param sampling_params: Sampling parameters for text generation.
e.g {'temperature': 0.9, 'top_p': 0.95}
:param args: extra args
:param kwargs: extra args

The default data format parsed by this interface is as follows:
Model Input:
请对下面的example文本回复的任务类别进行检测,并进行分类。备选的分类包括:["数学","代码","翻译","角色扮演","开放领域问答","特定领域问答", "提取", "生成", "头脑风暴", "分类","总结","改写", "其他"]。只回复对应的分类,不回复其他内容。
example文本:
{
"instruction": "找出方程 x2 - 3x = 0 的根。",
"input": "",
"output": "该方程可以写成 x(x-3)=0。\n\n根据乘法原理,x = 0或x - 3 = 0。\n\n因此,x1 = 0和x2 = 3是方程 x2 - 3x = 0 的两个根。"
}
Model Output:
数学
""" # noqa

        super().__init__(*args, **kwargs)
        self.num_proc = 1
Review comment (Collaborator): If enable_vllm is False, num_proc=1 will force this OP to run in a single process/GPU. Is this the desired behavior?

        self.prompt = prompt
        self.tag_list = tag_list
        self.enable_vllm = enable_vllm

        if enable_vllm:
            import torch
            from vllm import SamplingParams

            assert torch.cuda.device_count() >= 1, 'must be executed in CUDA'
            if not tensor_parallel_size:
                tensor_parallel_size = torch.cuda.device_count()
                logger.info(f'Set tensor_parallel_size to '
                            f'{tensor_parallel_size} for vllm.')
            self.model_key = prepare_model(
                model_type='vllm',
                pretrained_model_name_or_path=hf_model,
                trust_remote_code=trust_remote_code,
                tensor_parallel_size=tensor_parallel_size,
                max_model_len=max_model_len,
                max_num_seqs=max_num_seqs)
            self.sampling_params = SamplingParams(**sampling_params)
        else:
            self.model_key = prepare_model(
                model_type='huggingface',
                pretrained_model_name_or_path=hf_model,
                trust_remote_code=trust_remote_code)
            self.sampling_params = sampling_params

    def process(self, sample, rank=None):
        model, processor = get_model(self.model_key, rank, self.use_cuda())

        if self.enable_vllm:
            response = model.generate([
                self.prompt.format(text=sample[self.text_key],
                                   tag_list=self.tag_list)
            ], self.sampling_params)
            output = response[0].outputs[0].text
        else:
            inputs = processor([
                self.prompt.format(text=sample[self.text_key],
                                   tag_list=self.tag_list)
            ],
                               return_tensors='pt').to(model.device)
            response = model.generate(**inputs, **self.sampling_params)
            output = processor.decode(response.cpu()[0],
                                      skip_special_tokens=True)

        text_tags = []
        text_tags.append(output)
Review comment (Collaborator): It would be better to strip the output.
        sample[Fields.text_tags] = text_tags
Review comment (Collaborator): Please refer to #423 for adding a user-specified tag field name.

        return sample
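For context, a minimal usage sketch of the new OP outside a full data-juicer config (the parameter values and the sample below are illustrative only and not part of this PR; it assumes the default text_key of 'text' and that the model can be downloaded and run locally):

    from data_juicer.ops.mapper.text_tagging_by_prompt_mapper import \
        TextTaggingByPromptMapper
    from data_juicer.utils.constant import Fields

    # Hypothetical example: run the mapper with plain transformers inference.
    op = TextTaggingByPromptMapper(
        hf_model='Qwen/Qwen2-7B-Instruct',
        enable_vllm=False,
        sampling_params={'max_new_tokens': 16})  # keep the generated tag short

    sample = {'text': '{"instruction": "找出方程 x2 - 3x = 0 的根。", '
                      '"input": "", "output": "该方程可以写成 x(x-3)=0。"}'}
    tagged = op.process(sample)
    print(tagged[Fields.text_tags])  # expected to contain a tag such as '数学'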
3 changes: 3 additions & 0 deletions data_juicer/utils/constant.py
@@ -20,6 +20,9 @@ class Fields(object):
    video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__'
    video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__'

    # text_tags
    text_tags = DEFAULT_PREFIX + 'text_tags__'

    # the name of the original file from which this sample was derived.
    source_file = DEFAULT_PREFIX + 'source_file__'

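With this field in place, downstream steps can read the generated tags like the other tag fields. A minimal, hypothetical filtering helper (not part of this PR) might look like the sketch below; note that exact membership only matches if the tag has been stripped of surrounding whitespace, as suggested in the review comment above:

    from data_juicer.utils.constant import Fields

    def keep_math_samples(sample):
        # Keep only samples whose generated tag list contains '数学' (math).
        return '数学' in sample.get(Fields.text_tags, [])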
61 changes: 61 additions & 0 deletions tests/ops/mapper/test_text_tagging_by_prompt_mapper.py
@@ -0,0 +1,61 @@
import unittest

from data_juicer.ops.mapper.text_tagging_by_prompt_mapper import (
    DEFAULT_CLASSIFICATION_LIST, DEFAULT_CLASSIFICATION_PROMPT,
    TextTaggingByPromptMapper)
from data_juicer.utils.constant import Fields
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
                                              DataJuicerTestCaseBase)


def check_string_in_list(string_list, output):
    if not string_list:
        assert False, 'the candidate string list must not be empty'

    for string in string_list:
Review comment (drcege, Collaborator, Sep 12, 2024): Why not directly check output in string_list?

        if string in output:
            return

    assert False, 'no candidate string was found in the output'

# Skip tests for this OP in GitHub Actions due to disk space limitations.
# These tests have been run locally.
@SKIPPED_TESTS.register_module()
class TextTaggingByPromptTest(DataJuicerTestCaseBase):
    text_key = 'text'

    def _run_tagging(self,
                     samples,
                     enable_vllm=False,
                     sampling_params={},
                     **kwargs):
        op = TextTaggingByPromptMapper(
            hf_model='Qwen/Qwen2-7B-Instruct',
            prompt=DEFAULT_CLASSIFICATION_PROMPT,
            enable_vllm=enable_vllm,
            sampling_params=sampling_params,
            **kwargs)
        for sample in samples:
            result = op.process(sample)
            out_tag = result[Fields.text_tags]
            print(f'Output tag: {out_tag}')

            # check that one of the candidate tags appears in the output
            check_string_in_list(DEFAULT_CLASSIFICATION_LIST, out_tag)

    def test_tagging(self):
        samples = [{
            self.text_key:
            """{\n"instruction": "找出方程 x2 - 3x = 0 的根。",\n"input": "",\n"output": "该方程可以写成 x(x-3)=0。\n\n根据乘法原理,x = 0或x - 3 = 0。\n\n因此,x1 = 0和x2 = 3是方程 x2 - 3x = 0 的两个根。"\n}"""
        }]
        self._run_tagging(samples)

    def test_tagging_vllm(self):
        samples = [{
            self.text_key:
            """{\n"instruction": "找出方程 x2 - 3x = 0 的根。",\n"input": "",\n"output": "该方程可以写成 x(x-3)=0。\n\n根据乘法原理,x = 0或x - 3 = 0。\n\n因此,x1 = 0和x2 = 3是方程 x2 - 3x = 0 的两个根。"\n}"""
        }]
        self._run_tagging(
            samples,
            enable_vllm=True,
            max_model_len=1024,
            max_num_seqs=16,
            sampling_params={
                'temperature': 0.1,
                'top_p': 0.95,
                'max_tokens': 256
            })

Review comment (drcege, Collaborator, Sep 12, 2024): Can this OP run in multiple processes, especially without vllm? Please add more tests.

Review comment (drcege, Collaborator, Sep 12, 2024): It would be better to add tests for tensor_parallel_size.
if __name__ == '__main__':
    unittest.main()