From ad399674a5b050ce9189d80a0bcbb9dcd177723e Mon Sep 17 00:00:00 2001 From: Qirui-jiao Date: Wed, 28 Aug 2024 20:30:36 +0800 Subject: [PATCH 1/2] Upload text_pair_similarity_filter --- configs/config_all.yaml | 5 + data_juicer/ops/filter/__init__.py | 16 +-- .../ops/filter/text_pair_similarity_filter.py | 105 ++++++++++++++++++ data_juicer/utils/constant.py | 1 + docs/Operators.md | 3 +- docs/Operators_ZH.md | 3 +- .../test_text_pair_similarity_filter.py | 62 +++++++++++ 7 files changed, 186 insertions(+), 9 deletions(-) create mode 100644 data_juicer/ops/filter/text_pair_similarity_filter.py create mode 100644 tests/ops/filter/test_text_pair_similarity_filter.py diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 855d45731..c4448c200 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -372,6 +372,11 @@ process: - text_length_filter: # filter text with length out of specific range min_len: 10 # the min length of filter range max_len: 10000 # the max length of filter range + - text_pair_similarity_filter: # filter samples according to the similarity score between the text pair. + hf_clip: 'openai/clip-vit-base-patch32' # model name of the CLIP model on huggingface + min_score: 0.1 # the min similarity score of filter range + max_score: 1.0 # the max similarity score of filter range + any_or_all: "any" # keep this sample when any/all text pairs meet the filter condition - token_num_filter: # filter text with total token number out of specific range hf_tokenizer: EleutherAI/pythia-6.9b-deduped # name of used Hugging Face tokenizer min_num: 10 # the min number of filter range diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index abce40a5b..862b4e7e4 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -12,13 +12,13 @@ specified_field_filter, specified_numeric_field_filter, stopwords_filter, suffix_filter, text_action_filter, text_entity_dependency_filter, text_length_filter, - token_num_filter, video_aesthetics_filter, - video_aspect_ratio_filter, video_duration_filter, - video_frames_text_similarity_filter, video_motion_score_filter, - video_nsfw_filter, video_ocr_area_ratio_filter, - video_resolution_filter, video_tagging_from_frames_filter, - video_watermark_filter, word_repetition_filter, - words_num_filter) + text_pair_similarity_filter, token_num_filter, + video_aesthetics_filter, video_aspect_ratio_filter, + video_duration_filter, video_frames_text_similarity_filter, + video_motion_score_filter, video_nsfw_filter, + video_ocr_area_ratio_filter, video_resolution_filter, + video_tagging_from_frames_filter, video_watermark_filter, + word_repetition_filter, words_num_filter) from .alphanumeric_filter import AlphanumericFilter from .audio_duration_filter import AudioDurationFilter from .audio_nmf_snr_filter import AudioNMFSNRFilter @@ -47,6 +47,7 @@ from .text_action_filter import TextActionFilter from .text_entity_dependency_filter import TextEntityDependencyFilter from .text_length_filter import TextLengthFilter +from .text_pair_similarity_filter import TextPairSimilarityFilter from .token_num_filter import TokenNumFilter from .video_aesthetics_filter import VideoAestheticsFilter from .video_aspect_ratio_filter import VideoAspectRatioFilter @@ -104,6 +105,7 @@ 'FlaggedWordFilter', 'WordRepetitionFilter', 'VideoMotionScoreFilter', + 'TextPairSimilarityFilter' ] # yapf: enable diff --git a/data_juicer/ops/filter/text_pair_similarity_filter.py b/data_juicer/ops/filter/text_pair_similarity_filter.py new file mode 100644 index 000000000..7797dee2a --- /dev/null +++ b/data_juicer/ops/filter/text_pair_similarity_filter.py @@ -0,0 +1,105 @@ +import numpy as np +from jsonargparse.typing import ClosedUnitInterval + +from data_juicer.ops.base_op import OPERATORS, Filter +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'text_pair_similarity_filter' + +with AvailabilityChecking(['torch', 'transformers'], OP_NAME): + + import torch + import transformers # noqa: F401 + + # avoid hanging when calling clip in multiprocessing + torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +class TextPairSimilarityFilter(Filter): + """Filter to keep text pairs with similarities between texts + within a specific range.""" + + _accelerator = 'cuda' + + def __init__(self, + hf_clip='openai/clip-vit-base-patch32', + trust_remote_code=False, + min_score: ClosedUnitInterval = 0.1, + max_score: ClosedUnitInterval = 1.0, + any_or_all: str = 'any', + *args, + **kwargs): + """ + Initialization method. + + :param hf_clip: clip model name on huggingface to compute + the similarity between image and text. + :param min_score: The min similarity to keep samples. + :param max_score: The max similarity to keep samples. + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all images. 'any': keep this sample if any images meet the + condition. 'all': keep this sample only if all images meet the + condition. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_score = min_score + self.max_score = max_score + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. ' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_clip, + trust_remote_code=trust_remote_code) + + def compute_stats(self, sample, rank=None, context=False): + + # check if it's computed already + if StatsKeys.text_pair_similarity in sample[Fields.stats]: + return sample + + # there is no text in this sample + if (self.text_key not in sample or 'target_text' not in sample + or len(sample[self.text_key]) == 0 + or len(sample['target_text']) == 0): + sample[Fields.stats][StatsKeys.text_pair_similarity] = np.array( + [], dtype=np.float64) + return sample + + model, processor = get_model(self.model_key, rank, self.use_cuda()) + + text1 = sample[self.text_key] + text2 = sample['target_text'] + + text_tensors = processor([text1, text2], + padding=True, + return_tensors='pt').to(model.device) + text_features = model.get_text_features(**text_tensors) + + similarity = torch.cosine_similarity(text_features[0], + text_features[1], + dim=0) + sample[Fields.stats][StatsKeys.text_pair_similarity] = [similarity] + + return sample + + def process(self, sample, rank=None): + similarity = sample[Fields.stats][StatsKeys.text_pair_similarity] + if len(similarity) <= 0: + return True + + keep_bools = np.array([ + self.min_score <= sim_value <= self.max_score + for sim_value in similarity + ]) + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 13bddb687..fc8973a98 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -126,6 +126,7 @@ class StatsKeysConstant(object): special_char_ratio = 'special_char_ratio' stopwords_ratio = 'stopwords_ratio' text_len = 'text_len' + text_pair_similarity = 'text_pair_similarity' num_action = 'num_action' num_dependency_edges = 'num_dependency_edges' num_token = 'num_token' diff --git a/docs/Operators.md b/docs/Operators.md index a35210161..39aeeb02b 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -12,7 +12,7 @@ The operators in Data-Juicer are categorized into 5 types. |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | | [ Mapper ]( #mapper ) | 43 | Edits and transforms samples | -| [ Filter ]( #filter ) | 41 | Filters out low-quality samples | +| [ Filter ]( #filter ) | 42 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -127,6 +127,7 @@ All the specific operators are listed below, each featured with several capabili | text_action_filter | General | en, zh | Keeps samples containing action verbs in their texts | | text_entity_dependency_filter | General | en, zh | Keeps samples containing dependency edges for an entity in the dependency tree of the texts | | text_length_filter | General | en, zh | Keeps samples with total text length within the specified range | +| text_pair_similarity_filter | General | en, zh | Keeps text pairs with text feature cosine similarity within the specified range based on a CLIP model | | token_num_filter | General | en, zh | Keeps samples with token count within the specified range | | video_aesthetics_filter | Video | - | Keeps samples whose specified frames have aesthetics scores within the specified range | | video_aspect_ratio_filter | Video | - | Keeps samples containing videos with aspect ratios within the specified range | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 855d109a7..f6df6c909 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -12,7 +12,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | | [ Mapper ]( #mapper ) | 43 | 对数据样本进行编辑和转换 | -| [ Filter ]( #filter ) | 41 | 过滤低质量样本 | +| [ Filter ]( #filter ) | 42 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -125,6 +125,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | text_action_filter | General | en, zh | 保留文本部分包含动作的样本 | | text_entity_dependency_filter | General | en, zh | 保留文本部分的依存树中具有非独立实体的样本 | | text_length_filter | General | en, zh | 保留总文本长度在指定范围内的样本 | +| text_pair_similarity_filter | General | en, zh | 保留文本特征余弦相似度(基于CLIP模型)在指定范围内的样本 | | token_num_filter | General | en, zh | 保留token数在指定范围内的样本 | | video_aspect_ratio_filter | Video | - | 保留包含视频的宽高比在指定范围内的样本 | | video_duration_filter | Video | - | 保留包含视频的时长在指定范围内的样本 | diff --git a/tests/ops/filter/test_text_pair_similarity_filter.py b/tests/ops/filter/test_text_pair_similarity_filter.py new file mode 100644 index 000000000..34729faa3 --- /dev/null +++ b/tests/ops/filter/test_text_pair_similarity_filter.py @@ -0,0 +1,62 @@ +import os +import unittest + +from data_juicer.core.data import NestedDataset as Dataset + +from data_juicer.ops.filter.text_pair_similarity_filter import TextPairSimilarityFilter +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + + +class TextPairSimilarityFilterTest(DataJuicerTestCaseBase): + + hf_clip = 'openai/clip-vit-base-patch32' + + + @classmethod + def tearDownClass(cls) -> None: + super().tearDownClass(cls.hf_clip) + + def _run_filter(self, dataset: Dataset, op, num_proc=1): + + if Fields.stats not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + + dataset = dataset.map(op.compute_stats, + num_proc=num_proc, + with_rank=True) + dataset = dataset.filter(op.process, num_proc=num_proc) + dataset = dataset.select_columns(column_names=['text', 'target_text']) + res_list = dataset.to_list() + print(res_list) + + def test_no_eoc_special_token(self): + + ds_list = [{ + 'target_text': 'a lovely cat', + 'text': 'a lovely cat', + }, { + 'target_text': 'a lovely cat', + 'text': 'a cute cat', + }, { + 'target_text': 'a lovely cat', + 'text': 'a black dog', + }] + + + dataset = Dataset.from_list(ds_list) + op = TextPairSimilarityFilter(hf_clip=self.hf_clip, + any_or_all='any', + min_score=0.1, + max_score=0.85) + self._run_filter(dataset, op) + + +if __name__ == '__main__': + unittest.main() From 376936b3db6f269d2e408d8bb22161a3a2ec5f3b Mon Sep 17 00:00:00 2001 From: Qirui-jiao Date: Mon, 2 Sep 2024 21:20:55 +0800 Subject: [PATCH 2/2] update --- configs/config_all.yaml | 1 + .../ops/filter/text_pair_similarity_filter.py | 24 +++++++++++++------ .../test_text_pair_similarity_filter.py | 23 +++++++++++------- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 8a7045daa..89ce40cb7 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -404,6 +404,7 @@ process: hf_clip: 'openai/clip-vit-base-patch32' # model name of the CLIP model on huggingface min_score: 0.1 # the min similarity score of filter range max_score: 1.0 # the max similarity score of filter range + text_key_second: None # used to store the other sentence in the text pair any_or_all: "any" # keep this sample when any/all text pairs meet the filter condition - token_num_filter: # filter text with total token number out of specific range hf_tokenizer: EleutherAI/pythia-6.9b-deduped # name of used Hugging Face tokenizer diff --git a/data_juicer/ops/filter/text_pair_similarity_filter.py b/data_juicer/ops/filter/text_pair_similarity_filter.py index e8151f28d..635c4c640 100644 --- a/data_juicer/ops/filter/text_pair_similarity_filter.py +++ b/data_juicer/ops/filter/text_pair_similarity_filter.py @@ -1,3 +1,5 @@ +import logging + import numpy as np from jsonargparse.typing import ClosedUnitInterval @@ -6,6 +8,9 @@ from data_juicer.utils.constant import Fields, StatsKeys from data_juicer.utils.model_utils import get_model, prepare_model +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + OP_NAME = 'text_pair_similarity_filter' with AvailabilityChecking(['torch', 'transformers'], OP_NAME): @@ -29,6 +34,7 @@ def __init__(self, trust_remote_code=False, min_score: ClosedUnitInterval = 0.1, max_score: ClosedUnitInterval = 1.0, + text_key_second=None, any_or_all: str = 'any', *args, **kwargs): @@ -39,6 +45,8 @@ def __init__(self, the similarity between image and text. :param min_score: The min similarity to keep samples. :param max_score: The max similarity to keep samples. + :param text_key_second: used to store the other sentence + in the text pair. :param any_or_all: keep this sample with 'any' or 'all' strategy of all images. 'any': keep this sample if any images meet the condition. 'all': keep this sample only if all images meet the @@ -56,7 +64,7 @@ def __init__(self, self.model_key = prepare_model(model_type='huggingface', pretrained_model_name_or_path=hf_clip, trust_remote_code=trust_remote_code) - self.new_sample_key = ['target_text'] + self.text_key_second = text_key_second def compute_stats(self, sample, rank=None, context=False): @@ -65,13 +73,15 @@ def compute_stats(self, sample, rank=None, context=False): return sample # there is no target text - for temp_new_key in self.new_sample_key: - if temp_new_key not in sample or len(sample[temp_new_key]) == 0: - raise ValueError( - f'Key \'{temp_new_key}\' is not found in sample. ') + if self.text_key_second is None: + logger.error('This OP (text_pair_similarity_filter) requires \ + processing multiple fields, and you need to specify \ + valid `text_key_second`') # there is no text in this sample - if (self.text_key not in sample or len(sample[self.text_key]) == 0): + if (self.text_key not in sample or len(sample[self.text_key]) == 0 + or self.text_key_second not in sample + or len(sample[self.text_key_second]) == 0): sample[Fields.stats][StatsKeys.text_pair_similarity] = np.array( [], dtype=np.float64) return sample @@ -79,7 +89,7 @@ def compute_stats(self, sample, rank=None, context=False): model, processor = get_model(self.model_key, rank, self.use_cuda()) text1 = sample[self.text_key] - text2 = sample['target_text'] + text2 = sample[self.text_key_second] text_tensors = processor([text1, text2], padding=True, diff --git a/tests/ops/filter/test_text_pair_similarity_filter.py b/tests/ops/filter/test_text_pair_similarity_filter.py index b65849bc9..083849443 100644 --- a/tests/ops/filter/test_text_pair_similarity_filter.py +++ b/tests/ops/filter/test_text_pair_similarity_filter.py @@ -11,7 +11,10 @@ class TextPairSimilarityFilterTest(DataJuicerTestCaseBase): - hf_clip = 'openai/clip-vit-base-patch32' + hf_clip = "openai/clip-vit-base-patch32" + + text_key = "text" + text_key_second = "target_text" @classmethod @@ -31,21 +34,22 @@ def _run_filter(self, dataset: Dataset, op, num_proc=1): num_proc=num_proc, with_rank=True) dataset = dataset.filter(op.process, num_proc=num_proc) - dataset = dataset.select_columns(column_names=['text', 'target_text']) + dataset = dataset.select_columns(column_names=[self.text_key, + self.text_key_second]) res_list = dataset.to_list() print(res_list) def test_no_eoc_special_token(self): ds_list = [{ - 'target_text': 'a lovely cat', - 'text': 'a lovely cat', + self.text_key_second: 'a lovely cat', + self.text_key: 'a lovely cat', }, { - 'target_text': 'a lovely cat', - 'text': 'a cute cat', + self.text_key_second: 'a lovely cat', + self.text_key: 'a cute cat', }, { - 'target_text': 'a lovely cat', - 'text': 'a black dog', + self.text_key_second: 'a lovely cat', + self.text_key: 'a black dog', }] @@ -53,7 +57,8 @@ def test_no_eoc_special_token(self): op = TextPairSimilarityFilter(hf_clip=self.hf_clip, any_or_all='any', min_score=0.1, - max_score=0.85) + max_score=0.99, + text_key_second=self.text_key_second) self._run_filter(dataset, op)