From 931d6478defcec38cbecd2d04f3b8a3a358f22f3 Mon Sep 17 00:00:00 2001 From: Qirui-jiao Date: Thu, 22 Aug 2024 16:29:58 +0800 Subject: [PATCH 01/10] upload mllm_mapper --- configs/config_all.yaml | 2 + data_juicer/ops/mapper/__init__.py | 4 +- data_juicer/ops/mapper/mllm_mapper.py | 86 +++++++++++++++++++++++++++ docs/Operators.md | 3 +- docs/Operators_ZH.md | 3 +- tests/ops/mapper/test_mllm_mapper.py | 31 ++++++++++ 6 files changed, 126 insertions(+), 3 deletions(-) create mode 100644 data_juicer/ops/mapper/mllm_mapper.py create mode 100644 tests/ops/mapper/test_mllm_mapper.py diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 8273a30f4..e15fddf87 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -100,6 +100,8 @@ process: - image_face_blur_mapper: # blur faces detected in images blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] radius: 2 # radius of blur kernel + - mllm_mapper: # use MLLMs for visual question answering tasks + sampling_params: {} # sampling hyperparameters for text generation - nlpaug_en_mapper: # simply augment texts in English based on the nlpaug library sequential: false # whether combine all augmentation methods to a sequence. If it's True, a sample will be augmented by all opened augmentation methods sequentially. If it's False, each opened augmentation method would generate its augmented samples independently. aug_num: 1 # number of augmented samples to be generated. If `sequential` is True, there will be total aug_num augmented samples generated. If it's False, there will be (aug_num * #opened_aug_method) augmented samples generated. diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 5213498e9..c0714dc0d 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -4,7 +4,7 @@ clean_ip_mapper, clean_links_mapper, expand_macro_mapper, extract_qa_mapper, fix_unicode_mapper, image_blur_mapper, image_captioning_from_gpt4v_mapper, image_captioning_mapper, - image_diffusion_mapper, image_face_blur_mapper, + image_diffusion_mapper, image_face_blur_mapper, mllm_mapper, nlpaug_en_mapper, nlpcda_zh_mapper, punctuation_normalization_mapper, remove_bibliography_mapper, remove_comments_mapper, remove_header_mapper, @@ -39,6 +39,7 @@ from .image_captioning_mapper import ImageCaptioningMapper from .image_diffusion_mapper import ImageDiffusionMapper from .image_face_blur_mapper import ImageFaceBlurMapper +from .mllm_mapper import MllmMapper from .nlpaug_en_mapper import NlpaugEnMapper from .nlpcda_zh_mapper import NlpcdaZhMapper from .punctuation_normalization_mapper import PunctuationNormalizationMapper @@ -118,6 +119,7 @@ 'AudioFFmpegWrappedMapper', 'VideoSplitByDurationMapper', 'VideoFaceBlurMapper', + 'MllmMapper' ] # yapf: enable diff --git a/data_juicer/ops/mapper/mllm_mapper.py b/data_juicer/ops/mapper/mllm_mapper.py new file mode 100644 index 000000000..772ab52d5 --- /dev/null +++ b/data_juicer/ops/mapper/mllm_mapper.py @@ -0,0 +1,86 @@ +from typing import Dict + +import torch + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.ops.op_fusion import LOADED_IMAGES +from data_juicer.utils.mm_utils import load_image +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'mllm_mapper' +torch.set_num_threads(1) + + +@LOADED_IMAGES.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class MllmMapper(Mapper): + """Mapper to optimize instruction. + Recommended model list: [ + liuhaotia/llava-v1.6-vicuna-7b + ] + """ + _accelerator = 'cuda' + + def __init__(self, + hf_model: str = 'liuhaotia/llava-v1.6-vicuna-7b', + max_new_tokens=256, + sampling_params: Dict = {}, + *args, + **kwargs): + """ + Initialization method. + :param hf_model: Hugginface model id. + :param sampling_params: Sampling hyperparameters for text generation. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, num_proc=1, **kwargs) + + self.hf_model = hf_model + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_model) + self.sampling_params = sampling_params + self.max_new_tokens = max_new_tokens + + def process(self, sample=None, rank=None): + + model, processor = get_model(self.model_key, rank=rank, use_cuda=True) + + # there is no image in this sample + if self.image_key not in sample or not sample[self.image_key]: + return sample + + # load images + loaded_image_key = sample[self.image_key] + image = load_image(loaded_image_key) + + conversation = [ + { + 'role': + 'user', + 'content': [ + { + 'type': 'text', + 'text': sample[self.text_key] + }, + { + 'type': 'image' + }, + ], + }, + ] + prompt = processor.apply_chat_template(conversation, + add_generation_prompt=True) + + inputs = processor(images=image, text=prompt, + return_tensors='pt').to(model.device) + + response = model.generate(**inputs, + max_new_tokens=self.max_new_tokens, + **self.sampling_params) + output = processor.decode(response.cpu()[0], skip_special_tokens=True) + + sample[self.text_key] = output + + return sample diff --git a/docs/Operators.md b/docs/Operators.md index a35210161..4e80c9582 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. | Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 43 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 44 | Edits and transforms samples | | [ Filter ]( #filter ) | 41 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -64,6 +64,7 @@ All the specific operators are listed below, each featured with several capabili | image_captioning_mapper | Multimodal | - | generate samples whose captions are generated based on another model (such as blip2) and the figure within the original sample | | image_diffusion_mapper | Multimodal | - | Generate and augment images by stable diffusion model | | image_face_blur_mapper | Image | - | Blur faces detected in images | +| mllm_mapper | Multimodal | en, zh | Use multimodal large language models for image-text question answering tasks | | nlpaug_en_mapper | General | en | Simply augments texts in English based on the `nlpaug` library | | nlpcda_zh_mapper | General | zh | Simply augments texts in Chinese based on the `nlpcda` library | | punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 855d109a7..03850ae8d 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 43 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 44 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 41 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -63,6 +63,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | image_captioning_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(例如 blip2)和原始样本中的图形生成的。 | | image_diffusion_mapper | Multimodal | - | 用stable diffusion生成图像,对图像进行增强 | | image_face_blur_mapper | Image | - | 对图像中的人脸进行模糊处理 | +| mllm_mapper | Multimodal | en, zh | 使用多模态大语言模型执行图文问答任务 | | nlpaug_en_mapper | General | en | 使用`nlpaug`库对英语文本进行简单增强 | | nlpcda_zh_mapper | General | zh | 使用`nlpcda`库对中文文本进行简单增强 | | punctuation_normalization_mapper | General | en, zh | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | diff --git a/tests/ops/mapper/test_mllm_mapper.py b/tests/ops/mapper/test_mllm_mapper.py new file mode 100644 index 000000000..571259c1f --- /dev/null +++ b/tests/ops/mapper/test_mllm_mapper.py @@ -0,0 +1,31 @@ +import unittest +from data_juicer.ops.mapper.mllm_mapper import MllmMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class MllmMapperTest(DataJuicerTestCaseBase): + + text_key = 'text' + image_key = "images" + + def _run_mllm(self): + op = MllmMapper( + hf_model='llava-v1.6-vicuna-7b-hf' + ) + + samples = [ + {self.text_key: 'Describe this image.', self.image_key: "./crayon.jpg"}, + ] + + for sample in samples: + result = op.process(sample) + print(f'Output results: {result}') + + def test_mllm(self): + self._run_mllm() + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From c1846ec30261cdb1cff92f96ae5e54e438a944f9 Mon Sep 17 00:00:00 2001 From: Qirui-jiao <156628817+Qirui-jiao@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:38:17 +0800 Subject: [PATCH 02/10] Update config_all.yaml --- configs/config_all.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index e15fddf87..539eee87e 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -101,6 +101,7 @@ process: blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] radius: 2 # radius of blur kernel - mllm_mapper: # use MLLMs for visual question answering tasks + max_new_tokens: 256 # the maximum number of new tokens generated by the model sampling_params: {} # sampling hyperparameters for text generation - nlpaug_en_mapper: # simply augment texts in English based on the nlpaug library sequential: false # whether combine all augmentation methods to a sequence. If it's True, a sample will be augmented by all opened augmentation methods sequentially. If it's False, each opened augmentation method would generate its augmented samples independently. From ede8afcbb13d8e398c0ad6d1a5840a457aea32cb Mon Sep 17 00:00:00 2001 From: Qirui-jiao <156628817+Qirui-jiao@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:38:59 +0800 Subject: [PATCH 03/10] Update mllm_mapper.py --- data_juicer/ops/mapper/mllm_mapper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/data_juicer/ops/mapper/mllm_mapper.py b/data_juicer/ops/mapper/mllm_mapper.py index 772ab52d5..52b7feb1b 100644 --- a/data_juicer/ops/mapper/mllm_mapper.py +++ b/data_juicer/ops/mapper/mllm_mapper.py @@ -30,6 +30,8 @@ def __init__(self, """ Initialization method. :param hf_model: Hugginface model id. + :param max_new_tokens: the maximum number of new tokens + generated by the model. :param sampling_params: Sampling hyperparameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} :param args: extra args From d8e97f3544a4183336a2c2162cf7a149814a671b Mon Sep 17 00:00:00 2001 From: Qirui-jiao <156628817+Qirui-jiao@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:43:04 +0800 Subject: [PATCH 04/10] Update mllm_mapper.py --- data_juicer/ops/mapper/mllm_mapper.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/data_juicer/ops/mapper/mllm_mapper.py b/data_juicer/ops/mapper/mllm_mapper.py index 52b7feb1b..772ab52d5 100644 --- a/data_juicer/ops/mapper/mllm_mapper.py +++ b/data_juicer/ops/mapper/mllm_mapper.py @@ -30,8 +30,6 @@ def __init__(self, """ Initialization method. :param hf_model: Hugginface model id. - :param max_new_tokens: the maximum number of new tokens - generated by the model. :param sampling_params: Sampling hyperparameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} :param args: extra args From 5ac238ae80815ae9601b026e1894d0354e5c01c0 Mon Sep 17 00:00:00 2001 From: Qirui-jiao <156628817+Qirui-jiao@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:45:19 +0800 Subject: [PATCH 05/10] Update mllm_mapper.py --- data_juicer/ops/mapper/mllm_mapper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/data_juicer/ops/mapper/mllm_mapper.py b/data_juicer/ops/mapper/mllm_mapper.py index 772ab52d5..52b7feb1b 100644 --- a/data_juicer/ops/mapper/mllm_mapper.py +++ b/data_juicer/ops/mapper/mllm_mapper.py @@ -30,6 +30,8 @@ def __init__(self, """ Initialization method. :param hf_model: Hugginface model id. + :param max_new_tokens: the maximum number of new tokens + generated by the model. :param sampling_params: Sampling hyperparameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} :param args: extra args From 4c5a9f30f2ef048cea1b39d8518e571ee841d989 Mon Sep 17 00:00:00 2001 From: Qirui-jiao <156628817+Qirui-jiao@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:46:43 +0800 Subject: [PATCH 06/10] Update mllm_mapper.py --- data_juicer/ops/mapper/mllm_mapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_juicer/ops/mapper/mllm_mapper.py b/data_juicer/ops/mapper/mllm_mapper.py index 52b7feb1b..f18dacee2 100644 --- a/data_juicer/ops/mapper/mllm_mapper.py +++ b/data_juicer/ops/mapper/mllm_mapper.py @@ -30,7 +30,7 @@ def __init__(self, """ Initialization method. :param hf_model: Hugginface model id. - :param max_new_tokens: the maximum number of new tokens + :param max_new_tokens: the maximum number of new tokens generated by the model. :param sampling_params: Sampling hyperparameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} From 64d81dee74c0fca80d874d1295784378b7ba87f6 Mon Sep 17 00:00:00 2001 From: Qirui-jiao Date: Sat, 24 Aug 2024 10:37:48 +0800 Subject: [PATCH 07/10] update --- configs/config_all.yaml | 6 ++- data_juicer/ops/mapper/mllm_mapper.py | 63 +++++++++++++++++++-------- tests/ops/mapper/test_mllm_mapper.py | 9 ++-- 3 files changed, 56 insertions(+), 22 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index e15fddf87..8bcb14387 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -101,7 +101,11 @@ process: blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] radius: 2 # radius of blur kernel - mllm_mapper: # use MLLMs for visual question answering tasks - sampling_params: {} # sampling hyperparameters for text generation + hf_model: 'liuhaotia/llava-v1.6-vicuna-7b' # model name of the MLLM on huggingface + max_new_tokens: 256 # the maximum number of new tokens generated by the model + temperature: 0.2 # used to control the randomness of generated text + top_p: None # randomly select the next word from the group of words whose cumulative probability reaches p + num_beams: 1 # the larger the beam search size, the higher the quality of the generated text - nlpaug_en_mapper: # simply augment texts in English based on the nlpaug library sequential: false # whether combine all augmentation methods to a sequence. If it's True, a sample will be augmented by all opened augmentation methods sequentially. If it's False, each opened augmentation method would generate its augmented samples independently. aug_num: 1 # number of augmented samples to be generated. If `sequential` is True, there will be total aug_num augmented samples generated. If it's False, there will be (aug_num * #opened_aug_method) augmented samples generated. diff --git a/data_juicer/ops/mapper/mllm_mapper.py b/data_juicer/ops/mapper/mllm_mapper.py index 772ab52d5..78ee96fa3 100644 --- a/data_juicer/ops/mapper/mllm_mapper.py +++ b/data_juicer/ops/mapper/mllm_mapper.py @@ -1,5 +1,3 @@ -from typing import Dict - import torch from data_juicer.ops.base_op import OPERATORS, Mapper @@ -24,14 +22,23 @@ class MllmMapper(Mapper): def __init__(self, hf_model: str = 'liuhaotia/llava-v1.6-vicuna-7b', max_new_tokens=256, - sampling_params: Dict = {}, + temperature=0.2, + top_p=None, + num_beams=1, *args, **kwargs): """ Initialization method. - :param hf_model: Hugginface model id. - :param sampling_params: Sampling hyperparameters for text generation. - e.g {'temperature': 0.9, 'top_p': 0.95} + :param hf_model: hugginface model id. + :param max_new_tokens: the maximum number of new tokens + generated by the model. + :param temperature: used to control the randomness of \ + generated text. The higher the temperature, the more \ + random and creative the generated text will be. + :param top_p: randomly select the next word from the group \ + of words whose cumulative probability reaches p. + :param num_beams: the larger the beam search size, the higher \ + the quality of the generated text. :param args: extra args :param kwargs: extra args """ @@ -40,20 +47,32 @@ def __init__(self, self.hf_model = hf_model self.model_key = prepare_model(model_type='huggingface', pretrained_model_name_or_path=hf_model) - self.sampling_params = sampling_params self.max_new_tokens = max_new_tokens + self.temperature = temperature + self.top_p = top_p + self.num_beams = num_beams def process(self, sample=None, rank=None): - model, processor = get_model(self.model_key, rank=rank, use_cuda=True) - # there is no image in this sample if self.image_key not in sample or not sample[self.image_key]: return sample # load images - loaded_image_key = sample[self.image_key] - image = load_image(loaded_image_key) + loaded_image_keys = sample[self.image_key] + images = {} + for loaded_image_key in loaded_image_keys: + if loaded_image_key not in images: + # avoid loading the same images + image = load_image(loaded_image_key) + images[loaded_image_key] = image + + if torch.cuda.is_available(): + model, processor = get_model(self.model_key, + rank=rank, + use_cuda=True) + else: + model, processor = get_model(self.model_key, rank=rank) conversation = [ { @@ -73,14 +92,22 @@ def process(self, sample=None, rank=None): prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) - inputs = processor(images=image, text=prompt, - return_tensors='pt').to(model.device) + sample[self.text_key] = [] + + for image_key in images: + inputs = processor(images=images[image_key], + text=prompt, + return_tensors='pt').to(model.device) + + response = model.generate(**inputs, + max_new_tokens=self.max_new_tokens, + temperature=self.temperature, + top_p=self.top_p, + num_beams=self.num_beams) - response = model.generate(**inputs, - max_new_tokens=self.max_new_tokens, - **self.sampling_params) - output = processor.decode(response.cpu()[0], skip_special_tokens=True) + output = processor.decode(response.cpu()[0], + skip_special_tokens=True) - sample[self.text_key] = output + sample[self.text_key].append(output) return sample diff --git a/tests/ops/mapper/test_mllm_mapper.py b/tests/ops/mapper/test_mllm_mapper.py index 571259c1f..2d97f9357 100644 --- a/tests/ops/mapper/test_mllm_mapper.py +++ b/tests/ops/mapper/test_mllm_mapper.py @@ -10,13 +10,16 @@ class MllmMapperTest(DataJuicerTestCaseBase): text_key = 'text' image_key = "images" - def _run_mllm(self): + def _run_mllm(self, enable_vllm=False): op = MllmMapper( - hf_model='llava-v1.6-vicuna-7b-hf' + hf_model='llava-v1.6-vicuna-7b-hf', + temperature=0.9, + top_p=0.95, + max_new_tokens=512 ) samples = [ - {self.text_key: 'Describe this image.', self.image_key: "./crayon.jpg"}, + {self.text_key: 'Describe this image.', self.image_key: ["./ipod.jpg", "./crayon.jpg"]}, ] for sample in samples: From 8d32d00135787a48ca48bededc96d62af735cc6a Mon Sep 17 00:00:00 2001 From: Qirui-jiao Date: Tue, 27 Aug 2024 12:24:21 +0800 Subject: [PATCH 08/10] update --- configs/config_all.yaml | 2 +- data_juicer/ops/mapper/mllm_mapper.py | 11 ++++------- tests/ops/mapper/test_mllm_mapper.py | 2 +- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index a3329fc00..1bb8c3fc7 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -102,7 +102,7 @@ process: blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] radius: 2 # radius of blur kernel - mllm_mapper: # use MLLMs for visual question answering tasks - hf_model: 'liuhaotia/llava-v1.6-vicuna-7b' # model name of the MLLM on huggingface + hf_model: 'liuhaotian/llava-v1.6-vicuna-7b' # model name of the MLLM on huggingface max_new_tokens: 256 # the maximum number of new tokens generated by the model temperature: 0.2 # used to control the randomness of generated text top_p: None # randomly select the next word from the group of words whose cumulative probability reaches p diff --git a/data_juicer/ops/mapper/mllm_mapper.py b/data_juicer/ops/mapper/mllm_mapper.py index 78ee96fa3..1c01c1d19 100644 --- a/data_juicer/ops/mapper/mllm_mapper.py +++ b/data_juicer/ops/mapper/mllm_mapper.py @@ -12,7 +12,7 @@ @LOADED_IMAGES.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class MllmMapper(Mapper): - """Mapper to optimize instruction. + """Mapper to use MLLMs for visual question answering tasks. Recommended model list: [ liuhaotia/llava-v1.6-vicuna-7b ] @@ -67,12 +67,9 @@ def process(self, sample=None, rank=None): image = load_image(loaded_image_key) images[loaded_image_key] = image - if torch.cuda.is_available(): - model, processor = get_model(self.model_key, - rank=rank, - use_cuda=True) - else: - model, processor = get_model(self.model_key, rank=rank) + model, processor = get_model(model_key=self.model_key, + rank=rank, + use_cuda=self.use_cuda()) conversation = [ { diff --git a/tests/ops/mapper/test_mllm_mapper.py b/tests/ops/mapper/test_mllm_mapper.py index 2d97f9357..ef300cb2d 100644 --- a/tests/ops/mapper/test_mllm_mapper.py +++ b/tests/ops/mapper/test_mllm_mapper.py @@ -12,7 +12,7 @@ class MllmMapperTest(DataJuicerTestCaseBase): def _run_mllm(self, enable_vllm=False): op = MllmMapper( - hf_model='llava-v1.6-vicuna-7b-hf', + hf_model='liuhaotian/llava-v1.6-vicuna-7b', temperature=0.9, top_p=0.95, max_new_tokens=512 From c947e335cb19d76bc63a3bf52fdf0c34bf8e3752 Mon Sep 17 00:00:00 2001 From: Qirui-jiao <156628817+Qirui-jiao@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:37:35 +0800 Subject: [PATCH 09/10] Update mllm_mapper.py --- data_juicer/ops/mapper/mllm_mapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_juicer/ops/mapper/mllm_mapper.py b/data_juicer/ops/mapper/mllm_mapper.py index 1c01c1d19..de49f63ae 100644 --- a/data_juicer/ops/mapper/mllm_mapper.py +++ b/data_juicer/ops/mapper/mllm_mapper.py @@ -14,13 +14,13 @@ class MllmMapper(Mapper): """Mapper to use MLLMs for visual question answering tasks. Recommended model list: [ - liuhaotia/llava-v1.6-vicuna-7b + liuhaotian/llava-v1.6-vicuna-7b ] """ _accelerator = 'cuda' def __init__(self, - hf_model: str = 'liuhaotia/llava-v1.6-vicuna-7b', + hf_model: str = 'liuhaotian/llava-v1.6-vicuna-7b', max_new_tokens=256, temperature=0.2, top_p=None, From 38d0c8fd1b204ca56d35cb6898e3e02bec8c1bb6 Mon Sep 17 00:00:00 2001 From: Qirui-jiao <156628817+Qirui-jiao@users.noreply.github.com> Date: Fri, 30 Aug 2024 23:00:27 +0800 Subject: [PATCH 10/10] Update test_mllm_mapper.py --- tests/ops/mapper/test_mllm_mapper.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/ops/mapper/test_mllm_mapper.py b/tests/ops/mapper/test_mllm_mapper.py index ef300cb2d..5d6666340 100644 --- a/tests/ops/mapper/test_mllm_mapper.py +++ b/tests/ops/mapper/test_mllm_mapper.py @@ -1,10 +1,7 @@ import unittest from data_juicer.ops.mapper.mllm_mapper import MllmMapper -from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, - DataJuicerTestCaseBase) +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -# These tests have been tested locally. -@SKIPPED_TESTS.register_module() class MllmMapperTest(DataJuicerTestCaseBase): text_key = 'text' @@ -18,8 +15,13 @@ def _run_mllm(self, enable_vllm=False): max_new_tokens=512 ) + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + img2_path = os.path.join(data_path, 'img2.jpg') + img3_path = os.path.join(data_path, 'img3.jpg') + samples = [ - {self.text_key: 'Describe this image.', self.image_key: ["./ipod.jpg", "./crayon.jpg"]}, + {self.text_key: 'Describe this image.', self.image_key: [img2_path, img3_path]}, ] for sample in samples: @@ -31,4 +33,4 @@ def test_mllm(self): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()