diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 9cb64fa30..89ce40cb7 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -400,6 +400,12 @@ process: - text_length_filter: # filter text with length out of specific range min_len: 10 # the min length of filter range max_len: 10000 # the max length of filter range + - text_pair_similarity_filter: # filter samples according to the similarity score between the text pair. + hf_clip: 'openai/clip-vit-base-patch32' # model name of the CLIP model on huggingface + min_score: 0.1 # the min similarity score of filter range + max_score: 1.0 # the max similarity score of filter range + text_key_second: None # used to store the other sentence in the text pair + any_or_all: "any" # keep this sample when any/all text pairs meet the filter condition - token_num_filter: # filter text with total token number out of specific range hf_tokenizer: EleutherAI/pythia-6.9b-deduped # name of used Hugging Face tokenizer min_num: 10 # the min number of filter range diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index abce40a5b..862b4e7e4 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -12,13 +12,13 @@ specified_field_filter, specified_numeric_field_filter, stopwords_filter, suffix_filter, text_action_filter, text_entity_dependency_filter, text_length_filter, - token_num_filter, video_aesthetics_filter, - video_aspect_ratio_filter, video_duration_filter, - video_frames_text_similarity_filter, video_motion_score_filter, - video_nsfw_filter, video_ocr_area_ratio_filter, - video_resolution_filter, video_tagging_from_frames_filter, - video_watermark_filter, word_repetition_filter, - words_num_filter) + text_pair_similarity_filter, token_num_filter, + video_aesthetics_filter, video_aspect_ratio_filter, + video_duration_filter, video_frames_text_similarity_filter, + video_motion_score_filter, video_nsfw_filter, + video_ocr_area_ratio_filter, video_resolution_filter, + video_tagging_from_frames_filter, video_watermark_filter, + word_repetition_filter, words_num_filter) from .alphanumeric_filter import AlphanumericFilter from .audio_duration_filter import AudioDurationFilter from .audio_nmf_snr_filter import AudioNMFSNRFilter @@ -47,6 +47,7 @@ from .text_action_filter import TextActionFilter from .text_entity_dependency_filter import TextEntityDependencyFilter from .text_length_filter import TextLengthFilter +from .text_pair_similarity_filter import TextPairSimilarityFilter from .token_num_filter import TokenNumFilter from .video_aesthetics_filter import VideoAestheticsFilter from .video_aspect_ratio_filter import VideoAspectRatioFilter @@ -104,6 +105,7 @@ 'FlaggedWordFilter', 'WordRepetitionFilter', 'VideoMotionScoreFilter', + 'TextPairSimilarityFilter' ] # yapf: enable diff --git a/data_juicer/ops/filter/text_pair_similarity_filter.py b/data_juicer/ops/filter/text_pair_similarity_filter.py new file mode 100644 index 000000000..635c4c640 --- /dev/null +++ b/data_juicer/ops/filter/text_pair_similarity_filter.py @@ -0,0 +1,120 @@ +import logging + +import numpy as np +from jsonargparse.typing import ClosedUnitInterval + +from data_juicer.ops.base_op import OPERATORS, Filter +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +OP_NAME = 'text_pair_similarity_filter' + +with AvailabilityChecking(['torch', 'transformers'], OP_NAME): + + import torch + import transformers # noqa: F401 + + # avoid hanging when calling clip in multiprocessing + torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +class TextPairSimilarityFilter(Filter): + """Filter to keep text pairs with similarities between texts + within a specific range.""" + + _accelerator = 'cuda' + + def __init__(self, + hf_clip='openai/clip-vit-base-patch32', + trust_remote_code=False, + min_score: ClosedUnitInterval = 0.1, + max_score: ClosedUnitInterval = 1.0, + text_key_second=None, + any_or_all: str = 'any', + *args, + **kwargs): + """ + Initialization method. + + :param hf_clip: clip model name on huggingface to compute + the similarity between image and text. + :param min_score: The min similarity to keep samples. + :param max_score: The max similarity to keep samples. + :param text_key_second: used to store the other sentence + in the text pair. + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all images. 'any': keep this sample if any images meet the + condition. 'all': keep this sample only if all images meet the + condition. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_score = min_score + self.max_score = max_score + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. ' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_clip, + trust_remote_code=trust_remote_code) + self.text_key_second = text_key_second + + def compute_stats(self, sample, rank=None, context=False): + + # check if it's computed already + if StatsKeys.text_pair_similarity in sample[Fields.stats]: + return sample + + # there is no target text + if self.text_key_second is None: + logger.error('This OP (text_pair_similarity_filter) requires \ + processing multiple fields, and you need to specify \ + valid `text_key_second`') + + # there is no text in this sample + if (self.text_key not in sample or len(sample[self.text_key]) == 0 + or self.text_key_second not in sample + or len(sample[self.text_key_second]) == 0): + sample[Fields.stats][StatsKeys.text_pair_similarity] = np.array( + [], dtype=np.float64) + return sample + + model, processor = get_model(self.model_key, rank, self.use_cuda()) + + text1 = sample[self.text_key] + text2 = sample[self.text_key_second] + + text_tensors = processor([text1, text2], + padding=True, + return_tensors='pt').to(model.device) + text_features = model.get_text_features(**text_tensors) + + similarity = torch.cosine_similarity(text_features[0], + text_features[1], + dim=0) + sample[Fields.stats][StatsKeys.text_pair_similarity] = [similarity] + + return sample + + def process(self, sample, rank=None): + similarity = sample[Fields.stats][StatsKeys.text_pair_similarity] + if len(similarity) <= 0: + return True + + keep_bools = np.array([ + self.min_score <= sim_value <= self.max_score + for sim_value in similarity + ]) + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 13bddb687..fc8973a98 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -126,6 +126,7 @@ class StatsKeysConstant(object): special_char_ratio = 'special_char_ratio' stopwords_ratio = 'stopwords_ratio' text_len = 'text_len' + text_pair_similarity = 'text_pair_similarity' num_action = 'num_action' num_dependency_edges = 'num_dependency_edges' num_token = 'num_token' diff --git a/docs/Operators.md b/docs/Operators.md index 144550790..bb7fd4bc1 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -12,7 +12,7 @@ The operators in Data-Juicer are categorized into 5 types. |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | | [ Mapper ]( #mapper ) | 46 | Edits and transforms samples | -| [ Filter ]( #filter ) | 41 | Filters out low-quality samples | +| [ Filter ]( #filter ) | 42 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -130,6 +130,7 @@ All the specific operators are listed below, each featured with several capabili | text_action_filter | General | en, zh | Keeps samples containing action verbs in their texts | | text_entity_dependency_filter | General | en, zh | Keeps samples containing dependency edges for an entity in the dependency tree of the texts | | text_length_filter | General | en, zh | Keeps samples with total text length within the specified range | +| text_pair_similarity_filter | General | en, zh | Keeps text pairs with text feature cosine similarity within the specified range based on a CLIP model | | token_num_filter | General | en, zh | Keeps samples with token count within the specified range | | video_aesthetics_filter | Video | - | Keeps samples whose specified frames have aesthetics scores within the specified range | | video_aspect_ratio_filter | Video | - | Keeps samples containing videos with aspect ratios within the specified range | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 3d0e33df3..f1ee2b73a 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -12,7 +12,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | | [ Mapper ]( #mapper ) | 46 | 对数据样本进行编辑和转换 | -| [ Filter ]( #filter ) | 41 | 过滤低质量样本 | +| [ Filter ]( #filter ) | 42 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -128,6 +128,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | text_action_filter | General | en, zh | 保留文本部分包含动作的样本 | | text_entity_dependency_filter | General | en, zh | 保留文本部分的依存树中具有非独立实体的样本 | | text_length_filter | General | en, zh | 保留总文本长度在指定范围内的样本 | +| text_pair_similarity_filter | General | en, zh | 保留文本特征余弦相似度(基于CLIP模型)在指定范围内的样本 | | token_num_filter | General | en, zh | 保留token数在指定范围内的样本 | | video_aspect_ratio_filter | Video | - | 保留包含视频的宽高比在指定范围内的样本 | | video_duration_filter | Video | - | 保留包含视频的时长在指定范围内的样本 | diff --git a/tests/ops/filter/test_text_pair_similarity_filter.py b/tests/ops/filter/test_text_pair_similarity_filter.py new file mode 100644 index 000000000..083849443 --- /dev/null +++ b/tests/ops/filter/test_text_pair_similarity_filter.py @@ -0,0 +1,66 @@ +import os +import unittest + +from data_juicer.core.data import NestedDataset as Dataset + +from data_juicer.ops.filter.text_pair_similarity_filter import TextPairSimilarityFilter +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class TextPairSimilarityFilterTest(DataJuicerTestCaseBase): + + hf_clip = "openai/clip-vit-base-patch32" + + text_key = "text" + text_key_second = "target_text" + + + @classmethod + def tearDownClass(cls) -> None: + super().tearDownClass(cls.hf_clip) + + def _run_filter(self, dataset: Dataset, op, num_proc=1): + + if Fields.stats not in dataset.features: + # TODO: + # this is a temp solution, + # only add stats when calling filter op + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + + dataset = dataset.map(op.compute_stats, + num_proc=num_proc, + with_rank=True) + dataset = dataset.filter(op.process, num_proc=num_proc) + dataset = dataset.select_columns(column_names=[self.text_key, + self.text_key_second]) + res_list = dataset.to_list() + print(res_list) + + def test_no_eoc_special_token(self): + + ds_list = [{ + self.text_key_second: 'a lovely cat', + self.text_key: 'a lovely cat', + }, { + self.text_key_second: 'a lovely cat', + self.text_key: 'a cute cat', + }, { + self.text_key_second: 'a lovely cat', + self.text_key: 'a black dog', + }] + + + dataset = Dataset.from_list(ds_list) + op = TextPairSimilarityFilter(hf_clip=self.hf_clip, + any_or_all='any', + min_score=0.1, + max_score=0.99, + text_key_second=self.text_key_second) + self._run_filter(dataset, op) + + +if __name__ == '__main__': + unittest.main()