From 4933f5d0f6e9cce9f8bbb7cf82879bee05cb6fd8 Mon Sep 17 00:00:00 2001
From: panxuchen
Date: Fri, 15 Nov 2024 15:30:08 +0800
Subject: [PATCH 1/6] add ray minhash deduplicator

---
 configs/config_all.yaml                  |  11 +
 data_juicer/core/ray_data.py             |   4 +-
 data_juicer/ops/deduplicator/__init__.py |   3 +-
 .../ray_redis_minhash_deduplicator.py    | 365 ++++++++++++++++++
 data_juicer/utils/constant.py            |   1 +
 docs/Operators.md                        |   1 +
 docs/Operators_ZH.md                     |   1 +
 environments/dist_requires.txt           |   2 +-
 8 files changed, 385 insertions(+), 3 deletions(-)
 create mode 100644 data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py

diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 90fc18875..d251d24a2 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -643,6 +643,17 @@ process:
       redis_port: 6380  # the port of redis instance, please note that the default port of redis is 6379 which is the same as default port for ray, so we need to modify the default redis config to use it in other port
       lowercase: false  # whether to convert text to lower case
       ignore_non_character: false  # whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations
+  - ray_redis_minhash_deduplicator:  # the document deduplicator that can run on multi-nodes using minhashLSH algorithm
+      redis_address: 'redis://localhost:6379'  # the address of the redis instance
+      tokenization: space  # tokenization method for text. One of [space, punctuation, character, sentencepiece]
+      window_size: 5  # window size of shingling
+      num_permutations: 256  # number of permutations in minhash computing
+      jaccard_threshold: 0.7  # the min jaccard similarity threshold in near-duplicate detection. When the jaccard similarity of two sample texts is >= this threshold, they are regarded as similar samples and this op will only keep one of them after deduplication
+      num_bands: null  # number of bands in LSH. By default it's None, and it will be determined by an optimal parameter computation algorithm that minimizes the weighted sum of the probabilities of false positives and false negatives
+      num_rows_per_band: null  # number of rows in each band in LSH. By default it's None, and it will be determined by the same optimal parameter computation algorithm
+      lowercase: true  # whether to convert text to lower case
+      ignore_pattern: null  # whether to ignore sub-strings with specific pattern when computing minhash
+      tokenizer_model: null  # path for the sentencepiece model, used for sentencepiece tokenization
# Selector ops - frequency_specified_field_selector: # selector to select samples based on the sorted frequency of specified field value diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 0c131561e..30848bcf3 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -5,7 +5,7 @@ from data_juicer import cuda_device_count from data_juicer.core.data import DJDataset -from data_juicer.ops import Filter, Mapper +from data_juicer.ops import Deduplicator, Filter, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.process_utils import calculate_np @@ -123,6 +123,8 @@ def _run_single_op(self, op): self.data.write_json(op.stats_export_path, force_ascii=False) self.data = self.data.filter(op.process) + elif isinstance(op, Deduplicator): + self.data = op.run(self.data) else: logger.error( 'Ray executor only support Filter and Mapper OPs for now') diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index 56aec0e10..c368e196d 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -5,6 +5,7 @@ from .ray_basic_deduplicator import RayBasicDeduplicator from .ray_document_deduplicator import RayDocumentDeduplicator from .ray_image_deduplicator import RayImageDeduplicator +from .ray_redis_minhash_deduplicator import RayRedisMinhashDeduplicator from .ray_video_deduplicator import RayVideoDeduplicator from .video_deduplicator import VideoDeduplicator @@ -12,5 +13,5 @@ 'DocumentDeduplicator', 'DocumentMinhashDeduplicator', 'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator', 'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator', - 'VideoDeduplicator' + 'RayRedisMinhashDeduplicator', 'VideoDeduplicator' ] diff --git a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py new file mode 100644 index 000000000..8883feb4b --- /dev/null +++ b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py @@ -0,0 +1,365 @@ +import random +import time +from collections import defaultdict +from typing import Optional + +import numpy as np +import pandas as pd +import pyarrow as pa +import regex +from loguru import logger +from pydantic import Field, PositiveInt +from typing_extensions import Annotated + +from data_juicer.utils.constant import HashKeys +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import prepare_sentencepiece_model + +from ..base_op import OPERATORS, Deduplicator +from ..common.helper_func import split_on_whitespace +from .document_minhash_deduplicator import (MAX_HASH, MERSENNE_PRIME, + optimal_param, sha1_hash32) + +redis = LazyLoader('redis', 'redis') + + +def retry_on_busy(func): + + def wrapper(*args, **kwargs): + max_retries = 10 + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except Exception as e: + if 'BUSY' in str(e) and attempt < max_retries - 1: + time.sleep(random.uniform(0.1, 0.3) * (2**attempt)) + else: + raise + + return wrapper + + +class RedisUnionFind: + + def __init__(self, + prefix: str, + redis_address: str = 'redis://localhost:6379'): + self.prefix = prefix + self.redis_address = redis_address + self.redis = redis.from_url(url=redis_address) + self.set_key = f'{prefix}_UF_SET' + self.rank_key = f'{prefix}_UF_RANK' + self.incur_id_key = f'{prefix}_UF_INCURID' + + # Lua scripts + 
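+        # The script below implements a disjoint-set (union-find) forest
+        # stored in two Redis hashes: KEYS[1] maps node -> parent and
+        # KEYS[2] maps root -> rank. find() applies path compression and
+        # union merges by rank; running it all as a single Lua script
+        # makes each union atomic, so concurrent Ray workers cannot
+        # corrupt the forest.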
+        self.union_script = self.redis.register_script("""
+            local function find(x)
+                local path = {}
+                while true do
+                    local parent = redis.call('HGET', KEYS[1], x)
+                    if not parent then
+                        return nil
+                    end
+                    if parent == x then
+                        break
+                    end
+                    table.insert(path, x)
+                    x = parent
+                end
+                for _, node in ipairs(path) do
+                    redis.call('HSET', KEYS[1], node, x)
+                end
+                return x
+            end
+
+            local root_x = find(ARGV[1])
+            local root_y = find(ARGV[2])
+            if not root_x then
+                redis.call('HSET', KEYS[1], ARGV[1], ARGV[1])
+                redis.call('HSET', KEYS[2], ARGV[1], 0)
+                root_x = ARGV[1]
+            end
+            if not root_y then
+                redis.call('HSET', KEYS[1], ARGV[2], ARGV[2])
+                redis.call('HSET', KEYS[2], ARGV[2], 0)
+                root_y = ARGV[2]
+            end
+            if root_x == root_y then
+                return root_x
+            end
+            local rank_x = tonumber(redis.call('HGET', KEYS[2], root_x))
+            local rank_y = tonumber(redis.call('HGET', KEYS[2], root_y))
+            if rank_x < rank_y then
+                redis.call('HSET', KEYS[1], root_x, root_y)
+                return root_y
+            elseif rank_x > rank_y then
+                redis.call('HSET', KEYS[1], root_y, root_x)
+                return root_x
+            else
+                redis.call('HSET', KEYS[1], root_y, root_x)
+                redis.call('HINCRBY', KEYS[2], root_x, 1)
+                return root_x
+            end
+        """)
+
+    def get_uid(self):
+        return int(self.redis.incr(self.incur_id_key))
+
+    @retry_on_busy
+    def union(self, x, y):
+        return self.union_script(keys=[self.set_key, self.rank_key],
+                                 args=[x, y])
+
+    def is_ancestor(self, x):
+        ancestor = self.redis.hget(self.set_key, x)
+        return ancestor is None or int(ancestor) == x
+
+    # make instances picklable for Ray workers: unpickling calls the
+    # constructor again, so each worker gets its own Redis connection
+    def __reduce__(self):
+        return (RedisUnionFind, (self.prefix, self.redis_address))
+
+    def clean(self):
+        self.redis.delete(self.set_key, self.rank_key, self.incur_id_key)
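+
+# Intended flow (illustrative): every sample receives an increasing id
+# from get_uid(), union(x, y) merges two ids whose documents collide in
+# a MinHash band, and afterwards is_ancestor(x) is True only for the
+# representative of each duplicate group, so exactly one copy survives.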
+
+
+OP_NAME = 'ray_redis_minhash_deduplicator'
+
+
+@OPERATORS.register_module(OP_NAME)
+class RayRedisMinhashDeduplicator(Deduplicator):
+    """
+    A basic exact matching deduplicator for RAY.
+    Although its functionality is deduplication,
+    it is implemented as Filter sub-class.
+    """
+
+    def __init__(
+        self,
+        tokenization: str = 'space',
+        window_size: PositiveInt = 5,
+        lowercase: bool = True,
+        ignore_pattern: Optional[str] = None,
+        num_permutations: PositiveInt = 256,
+        jaccard_threshold: Annotated[float, Field(ge=0, le=1)] = 0.7,
+        num_bands: Optional[PositiveInt] = None,
+        num_rows_per_band: Optional[PositiveInt] = None,
+        tokenizer_model: Optional[str] = None,
+        redis_address: str = 'redis://localhost:6379',
+        *args,
+        **kwargs,
+    ):
+        """
+        Initialization method.
+
+        :param tokenization: tokenization method for sample texts. It
+            should be one of [space, punctuation, character,
+            sentencepiece]. For English-like languages, we recommend
+            using 'space'; for Chinese-like languages, we recommend
+            using 'character'; and for multiple languages, we recommend
+            using 'sentencepiece'. If using 'sentencepiece', please
+            provide the model path in the 'tokenizer_model' field.
+        :param window_size: window size of shingling
+        :param lowercase: whether to convert text to lower case first
+        :param ignore_pattern: whether to ignore sub-strings with
+            specific pattern when computing minhash
+        :param num_permutations: number of permutations in minhash
+            computing
+        :param jaccard_threshold: the min jaccard similarity threshold
+            in near-duplicate detection. When the jaccard similarity of
+            two sample texts is >= this threshold, they are regarded as
+            similar samples and this op will only keep one of them after
+            deduplication
+        :param num_bands: number of bands in LSH. By default it's None,
+            and it will be determined by an optimal parameter
+            computation algorithm that minimizes the weighted sum of
+            the probabilities of false positives and false negatives
+        :param num_rows_per_band: number of rows in each band in LSH.
+            By default it's None, and it will be determined by the same
+            optimal parameter computation algorithm
+        :param tokenizer_model: path for the sentencepiece model, used
+            for sentencepiece tokenization.
+        :param redis_address: address of your redis instance, e.g.
+            'redis://localhost:6379'
+        """
+        super().__init__(*args, **kwargs)
+        # about minhash computation
+        self.tokenization = tokenization
+        self.window_size = window_size
+        self.lowercase = lowercase
+        self.ignore_pattern = ignore_pattern
+        if self.ignore_pattern:
+            self.ignore_pattern = regex.compile(self.ignore_pattern)
+
+        # check parameters
+        if self.ignore_pattern and self.tokenization == 'punctuation':
+            logger.warning('Be careful that tokenization with punctuations '
+                           'won\'t work if the ignore pattern includes '
+                           'punctuations.')
+        self.punctuation_pattern = regex.compile(r'\p{P}')
+
+        if self.tokenization == 'sentencepiece':
+            if tokenizer_model is None:
+                raise ValueError("To use 'sentencepiece' tokenization, "
+                                 "'tokenizer_model' is required.")
+            self.tokenizer = prepare_sentencepiece_model(tokenizer_model)
+        else:
+            self.tokenizer = None
+
+        # about deduplication
+        self.num_permutation = num_permutations
+        self.jaccard_threshold = jaccard_threshold
+        self.num_bands = num_bands
+        self.num_rows_per_band = num_rows_per_band
+
+        # initialize deduplication parameters
+        # check number of bands and rows
+        if self.num_bands is None or self.num_rows_per_band is None:
+            self.num_bands, self.num_rows_per_band = optimal_param(
+                self.jaccard_threshold,
+                self.num_permutation,
+            )
+
+        # compute hash ranges and create hash tables
+        self.hash_ranges = [(i * self.num_rows_per_band,
+                             (i + 1) * self.num_rows_per_band)
+                            for i in range(self.num_bands)]
+        self.hash_tables = [defaultdict(set) for _ in range(self.num_bands)]
+
+        # generate permutations
+        gen = np.random.RandomState(seed=42)
+        self.perm_a, self.perm_b = np.array(
+            [(
+                gen.randint(1, MERSENNE_PRIME, dtype=np.uint64),
+                gen.randint(0, MERSENNE_PRIME, dtype=np.uint64),
+            ) for _ in range(self.num_permutation)],
+            dtype=np.uint64,
+        ).T
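+
+    # Note on the banding above: with b bands of r rows each
+    # (b * r <= num_permutations), two documents with Jaccard
+    # similarity s collide in at least one band with probability
+    # 1 - (1 - s^r)^b. optimal_param() (shared with
+    # document_minhash_deduplicator) picks (b, r) to minimize the
+    # weighted false-positive/false-negative mass around
+    # jaccard_threshold when num_bands/num_rows_per_band are None.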
+
+    def run(self, dataset):
+        from ray.data.aggregate import AggregateFn
+
+        union_find = RedisUnionFind(self.redis_address)
+
+        def add_uid_column(table: pa.Table) -> pa.Table:
+            new_column_data = [union_find.get_uid() for _ in range(len(table))]
+            new_table = table.append_column(HashKeys.uid, [new_column_data])
+            return new_table
+
+        def calculate_minhash(table: pa.Table) -> pa.Table:
+            ids = table.column(HashKeys.uid).to_pandas()
+            texts = table.column(self.text_key).to_pandas()
+            hashes = texts.apply(lambda x: self.compute_minhash(x))
+            hashes = pa.Array.from_pandas(hashes).flatten()
+
+            repeated_ids = pa.Array.from_pandas(ids.repeat(self.num_bands))
+
+            return pa.Table.from_arrays([repeated_ids, hashes],
+                                        names=[HashKeys.uid, HashKeys.minhash])
+
+        def _is_null(r):
+            return pd.isnull(r)
+
+        class UnionFn(AggregateFn):
+
+            def __init__(self, union_find):
+                union_find = union_find  # closures below capture this arg
+
+                def accumulate(cur, row):
+                    if _is_null(row):
+                        return cur
+                    elif _is_null(cur):
+                        return row[HashKeys.uid]
+                    else:
+                        root = union_find.union(row[HashKeys.uid], cur)
+                        return int(root)
+
+                def merge(a, b):
+                    if _is_null(a):
+                        return b
+                    if _is_null(b):
+                        return a
+                    root = union_find.union(a, b)
+                    return int(root)
+
+                super().__init__(
+                    init=lambda k: None,
+                    accumulate_row=accumulate,
+                    merge=merge,
+                    name='union',
+                )
+
+        def filter_with_union_find(table: pa.Table) -> pa.Table:
+            uids = table.column(HashKeys.uid).to_pandas()
+            mask = pa.Array.from_pandas(
+                uids.apply(lambda x: union_find.is_ancestor(x)))
+            return table.filter(mask)
+
+        dataset_with_id = dataset.map_batches(
+            add_uid_column, batch_format='pyarrow').materialize()
+        dataset_with_id.map_batches(calculate_minhash,
+                                    batch_format='pyarrow').groupby(
+                                        HashKeys.minhash).aggregate(
+                                            UnionFn(union_find)).materialize()
+        result = dataset_with_id.map_batches(filter_with_union_find,
+                                             batch_format='pyarrow')
+        logger.info(f'Keep {result.count()} samples after MinHash dedup.')
+        union_find.clean()
+        return result
+
+    def compute_minhash(self, text):
+        """
+        Compute the MinHash band digests of one text.
+
+        :param text: input text
+        :return: a list of num_bands serialized band digests, each
+            suffixed with its band index so that groupby only matches
+            digests from the same band
+        """
+        if self.lowercase:
+            text = text.lower()
+        if self.ignore_pattern:
+            text = self.ignore_pattern.sub('', text)
+
+        # get tokens for different tokenization method
+        tokens = set()
+        if self.tokenization == 'character':
+            tokens = {
+                str.encode(text[i:i + self.window_size])
+                for i in range(len(text) - self.window_size)
+            }
+        elif self.tokenization == 'punctuation':
+            tokens = self.punctuation_pattern.split(text)
+            tokens = {
+                str.encode(' '.join(tokens[i:i + self.window_size]))
+                for i in range(len(tokens) - self.window_size)
+            }
+        elif self.tokenization == 'space':
+            tokens = split_on_whitespace(text)
+            tokens = {
+                str.encode(' '.join(tokens[i:i + self.window_size]))
+                for i in range(len(tokens) - self.window_size)
+            }
+        elif self.tokenization == 'sentencepiece':
+            tokens = self.tokenizer.encode(text, out_type=str)
+            tokens = {
+                str.encode(''.join(tokens[i:i + self.window_size]))
+                for i in range(len(tokens) - self.window_size)
+            }
+        else:
+            raise NotImplementedError(
+                f'Unimplemented tokenization method [{self.tokenization}]')
+
+        # compute minhash value
+        hv = np.array([sha1_hash32(token) for token in tokens],
+                      dtype=np.uint64)
+        phv = np.bitwise_and(
+            ((hv * np.tile(self.perm_a,
+                           (len(hv), 1)).T).T + self.perm_b) % MERSENNE_PRIME,
+            MAX_HASH)
+        hash_values = np.vstack([
+            phv,
+            np.ones(self.num_permutation, dtype=np.uint64) * MAX_HASH
+        ]).min(axis=0)
+        return [
+            bytes(hash_values[start:end].byteswap().data) +
+            start.to_bytes(8, byteorder='little')
+            for start, end in self.hash_ranges
+            # group by minhash || band_id
+        ]
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index ab88035b9..1fe8d7002 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -216,6 +216,7 @@ class StatsKeys(object, metaclass=StatsKeysMeta):
 class HashKeys(object):
+    uid = DEFAULT_PREFIX + 'uid'
     hash = DEFAULT_PREFIX + 'hash'
     minhash = DEFAULT_PREFIX + 'minhash'
     simhash = DEFAULT_PREFIX + 'simhash'
diff --git a/docs/Operators.md b/docs/Operators.md
index 7717ba434..f1a20c9ef 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -173,6 +173,7 @@ All the specific operators are listed below, each featured with several capabili
 | document_simhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level using SimHash | [code](../data_juicer/ops/deduplicator/document_simhash_deduplicator.py) |
[tests](../tests/ops/deduplicator/test_document_simhash_deduplicator.py) | | image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | Deduplicates samples at document-level using exact matching of images between documents | [code](../data_juicer/ops/deduplicator/image_deduplicator.py) | [tests](../tests/ops/deduplicator/test_image_deduplicator.py) | | video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Deduplicates samples at document-level using exact matching of videos between documents | [code](../data_juicer/ops/deduplicator/video_deduplicator.py) | [tests](../tests/ops/deduplicator/test_video_deduplicator.py) | +| ray_redis_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level using MinHashLSH based on Ray and Redis | [code](../data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py) | - | | ray_document_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level by comparing MD5 hash on ray | [code](../data_juicer/ops/deduplicator/ray_document_deduplicator.py) | - | | ray_image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | Deduplicates samples at document-level using exact matching of images between documents on ray | [code](../data_juicer/ops/deduplicator/ray_image_deduplicator.py) | - | | ray_video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Deduplicates samples at document-level using exact matching of videos between documents on ray | [code](../data_juicer/ops/deduplicator/ray_video_deduplicator.py) | - | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 81aee2149..b1194f250 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -172,6 +172,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | document_simhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 SimHash 在文档级别对样本去重 | [code](../data_juicer/ops/deduplicator/document_simhash_deduplicator.py) | [tests](../tests/ops/deduplicator/test_document_simhash_deduplicator.py) | | image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 使用文档之间图像的精确匹配在文档级别删除重复样本 | [code](../data_juicer/ops/deduplicator/image_deduplicator.py) | [tests](../tests/ops/deduplicator/test_image_deduplicator.py) | | video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 使用文档之间视频的精确匹配在文档级别删除重复样本 | [code](../data_juicer/ops/deduplicator/video_deduplicator.py) | [tests](../tests/ops/deduplicator/test_video_deduplicator.py) | +| ray_redis_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 
MinHashLSH 在文档级别对样本去重,面向 RAY 分布式模式 | [code](../data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py) | - | | ray_document_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较 MD5 哈希值在文档级别对样本去重,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_document_deduplicator.py) | - | | ray_image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 使用文档之间图像的精确匹配在文档级别删除重复样本,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_image_deduplicator.py) | - | | ray_video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 使用文档之间视频的精确匹配在文档级别删除重复样本,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_video_deduplicator.py) | - | diff --git a/environments/dist_requires.txt b/environments/dist_requires.txt index 4060a654f..b6ab28d06 100644 --- a/environments/dist_requires.txt +++ b/environments/dist_requires.txt @@ -1,2 +1,2 @@ -ray==2.31.0 +ray<=2.38.0 redis>=5.0.0 From 6b79f9004b82fe43c231e97d1d9dc4f2c00b6182 Mon Sep 17 00:00:00 2001 From: panxuchen Date: Fri, 15 Nov 2024 16:24:03 +0800 Subject: [PATCH 2/6] fix redis prefix --- .../ops/deduplicator/ray_redis_minhash_deduplicator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py index 8883feb4b..4a414fa00 100644 --- a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py @@ -1,5 +1,6 @@ import random import time +import uuid from collections import defaultdict from typing import Optional @@ -233,11 +234,12 @@ def __init__( ) for _ in range(self.num_permutation)], dtype=np.uint64, ).T + self.redis_address = redis_address def run(self, dataset): from ray.data.aggregate import AggregateFn - union_find = RedisUnionFind(self.redis_address) + union_find = RedisUnionFind(prefix=uuid.uuid4().hex[:8], redis_address=self.redis_address) def add_uid_column(table: pa.Table) -> pa.Table: new_column_data = [union_find.get_uid() for _ in range(len(table))] From 991e2906615c8fdd67179f41dd5ae9b808ba23e6 Mon Sep 17 00:00:00 2001 From: panxuchen Date: Fri, 15 Nov 2024 16:28:18 +0800 Subject: [PATCH 3/6] fix redis prefix --- data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py index 4a414fa00..72c250af1 100644 --- a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py @@ -239,7 +239,8 @@ def __init__( def run(self, dataset): from ray.data.aggregate import AggregateFn - union_find = RedisUnionFind(prefix=uuid.uuid4().hex[:8], redis_address=self.redis_address) + union_find = RedisUnionFind(prefix=uuid.uuid4().hex[:8], + redis_address=self.redis_address) def add_uid_column(table: pa.Table) -> pa.Table: new_column_data = [union_find.get_uid() for _ in range(len(table))] From 58e357f2ef8e5982c7408ab2ace822b923de03e3 Mon Sep 17 00:00:00 2001 From: panxuchen Date: Mon, 18 Nov 2024 17:50:42 +0800 Subject: [PATCH 4/6] fix output bug --- 
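A likely explanation of the bug fixed here (inferred from the change,
not stated in the commit): result.count() already triggers one
execution for the log line, but without materialize() the returned
dataset is re-evaluated lazily by downstream consumers, after
union_find.clean() has deleted the Redis keys; at that point
is_ancestor() returns True for every id and no duplicates are dropped.
Materializing pins the filtered rows while the union-find state still
exists.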
.../ops/deduplicator/ray_redis_minhash_deduplicator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py index 72c250af1..203fcf059 100644 --- a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py @@ -302,8 +302,8 @@ def filter_with_union_find(table: pa.Table) -> pa.Table: batch_format='pyarrow').groupby( HashKeys.minhash).aggregate( UnionFn(union_find)).materialize() - result = dataset_with_id.map_batches(filter_with_union_find, - batch_format='pyarrow') + result = dataset_with_id.map_batches( + filter_with_union_find, batch_format='pyarrow').materialize() logger.info(f'Keep {result.count()} samples after MinHash dedup.') union_find.clean() return result From e1b76f587e0399090daac5cd56e4d3d7a8669274 Mon Sep 17 00:00:00 2001 From: panxuchen Date: Wed, 20 Nov 2024 14:06:39 +0800 Subject: [PATCH 5/6] fix comments --- configs/config_all.yaml | 11 ++++------- .../ops/deduplicator/ray_basic_deduplicator.py | 13 +++++-------- .../ops/deduplicator/ray_document_deduplicator.py | 9 +++------ .../ops/deduplicator/ray_image_deduplicator.py | 9 +++------ .../deduplicator/ray_redis_minhash_deduplicator.py | 4 +--- .../ops/deduplicator/ray_video_deduplicator.py | 9 +++------ docs/Operators.md | 2 +- docs/Operators_ZH.md | 2 +- 8 files changed, 21 insertions(+), 38 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index d251d24a2..46c3c502e 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -632,19 +632,16 @@ process: - video_deduplicator: # deduplicator to deduplicate samples at document-level using exact matching of videos between documents. consider_text: false # whether to consider text hash together with video hash when applying deduplication. - ray_video_deduplicator: # the simple video deduplicator that can run on multi-nodes using md5 hashing exact matching method - redis_host: 'redis_host' # the host of the redis instance - redis_port: 6380 # the port of redis instance, please note that the default port of redis is 6379 which is the same as default port for ray, so we need to modify the default redis config to use it in other port + redis_address: 'redis://localhost:6379' # the address of the redis instance - ray_image_deduplicator: # the simple image deduplicator that can deduplicate samples at document-level using exact matching of images between documents. - redis_host: 'redis_host' # the host of the redis instance - redis_port: 6380 # the port of redis instance, please note that the default port of redis is 6379 which is the same as default port for ray, so we need to modify the default redis config to use it in other port + redis_address: 'redis://localhost:6379' # the address of the redis instance method: phash # hash method for image. 
One of [phash, dhash, whash, ahash] - ray_document_deduplicator: # the simple document deduplicator that can run on multi-nodes using md5 hashing exact matching method - redis_host: 'redis_host' # the host of the redis instance - redis_port: 6380 # the port of redis instance, please note that the default port of redis is 6379 which is the same as default port for ray, so we need to modify the default redis config to use it in other port + redis_address: 'redis://localhost:6379' # the address of the redis instance lowercase: false # whether to convert text to lower case ignore_non_character: false # whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations - ray_redis_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm - redis_address: 'redis://localhost:6379' # the address of the redis instance + redis_address: 'redis://localhost:6379' # the address of the redis instance tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece] window_size: 5 # window size of shingling num_permutations: 256 # number of permutations in minhash computing diff --git a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py index dad317d17..3fb902386 100644 --- a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py @@ -19,23 +19,20 @@ class RayBasicDeduplicator(Filter): EMPTY_HASH_VALUE = 'EMPTY' def __init__(self, - redis_host: str = 'localhost', - redis_port: PositiveInt = 6380, + redis_address: str = 'redis://localhost:6379', *args, **kwargs): """ Initialization. - :param redis_host: the hostname of redis server - :param redis_port: the port of redis server + :param redis_address: the address of redis server :param args: extra args :param kwargs: extra args """ super().__init__(*args, **kwargs) - self.redis_host = redis_host - self.redis_port = redis_port + self.redis_address = redis_address # TODO: add a barrier to ensure that flushdb is performed before # the operator is called - r = redis.StrictRedis(host=self.redis_host, port=self.redis_port, db=0) + r = redis.from_url(url=redis_address) r.flushdb(0) def calculate_hash(self, sample, context=False): @@ -44,7 +41,7 @@ def calculate_hash(self, sample, context=False): def compute_stats_single(self, sample, context=False): # init redis client - r = redis.StrictRedis(host=self.redis_host, port=self.redis_port, db=0) + r = redis.from_url(url=self.redis_address) # compute hash md5_value = self.calculate_hash(sample, context) # check existing diff --git a/data_juicer/ops/deduplicator/ray_document_deduplicator.py b/data_juicer/ops/deduplicator/ray_document_deduplicator.py index ce5cced4e..667f86e38 100644 --- a/data_juicer/ops/deduplicator/ray_document_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_document_deduplicator.py @@ -17,24 +17,21 @@ class RayDocumentDeduplicator(RayBasicDeduplicator): """ def __init__(self, - redis_host: str = 'localhost', - redis_port: PositiveInt = 6380, + redis_address: str = 'redis://localhost:6379', lowercase: bool = False, ignore_non_character: bool = False, *args, **kwargs): """ Initialization method. 
- :param redis_host: the hostname of redis server - :param redis_port: the port of redis server + :param redis_address: the address of redis server :param lowercase: Whether to convert sample text to lower case :param ignore_non_character: Whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations :param args: extra args :param kwargs: extra args. """ - super().__init__(redis_host=redis_host, - redis_port=redis_port, + super().__init__(redis_address=redis_address, *args, **kwargs) self.lowercase = lowercase diff --git a/data_juicer/ops/deduplicator/ray_image_deduplicator.py b/data_juicer/ops/deduplicator/ray_image_deduplicator.py index 7ca0d10f2..7610dc30d 100644 --- a/data_juicer/ops/deduplicator/ray_image_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_image_deduplicator.py @@ -36,20 +36,17 @@ class RayImageDeduplicator(RayBasicDeduplicator): """ def __init__(self, - redis_host: str = 'localhost', - redis_port: PositiveInt = 6380, + redis_address: str = 'redis://localhost:6379', method: str = 'phash', *args, **kwargs): """ Initialization. - :param redis_host: the hostname of redis server - :param redis_port: the port of redis server + :param redis_address: the address of redis server :param args: extra args :param kwargs: extra args """ - super().__init__(redis_host=redis_host, - redis_port=redis_port, + super().__init__(redis_address=redis_address, *args, **kwargs) if method not in HASH_METHOD: diff --git a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py index 203fcf059..14ce9fc28 100644 --- a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py @@ -128,9 +128,7 @@ def clean(self): @OPERATORS.register_module(OP_NAME) class RayRedisMinhashDeduplicator(Deduplicator): """ - A basic exact matching deduplicator for RAY. - Although its functionality is deduplication, - it is implemented as Filter sub-class. + A MinhashLSH deduplicator based on RAY and Redis. """ def __init__( diff --git a/data_juicer/ops/deduplicator/ray_video_deduplicator.py b/data_juicer/ops/deduplicator/ray_video_deduplicator.py index 902ca1979..342abf7a1 100644 --- a/data_juicer/ops/deduplicator/ray_video_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_video_deduplicator.py @@ -21,19 +21,16 @@ class RayVideoDeduplicator(RayBasicDeduplicator): """ def __init__(self, - redis_host: str = 'localhost', - redis_port: PositiveInt = 6380, + redis_address: str = 'redis://localhost:6379', *args, **kwargs): """ Initialization. - :param redis_host: the hostname of redis server - :param redis_port: the port of redis server + :param redis_address: the address of redis server :param args: extra args :param kwargs: extra args """ - super().__init__(redis_host=redis_host, - redis_port=redis_port, + super().__init__(redis_address=redis_address, *args, **kwargs) diff --git a/docs/Operators.md b/docs/Operators.md index f1a20c9ef..4282f18df 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -13,7 +13,7 @@ The operators in Data-Juicer are categorized into 5 types. 
| [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | | [ Mapper ]( #mapper ) | 58 | Edits and transforms samples | | [ Filter ]( #filter ) | 44 | Filters out low-quality samples | -| [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | +| [ Deduplicator ]( #deduplicator ) | 9 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index b1194f250..5adc44e41 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -13,7 +13,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 | | [ Mapper ]( #mapper ) | 58 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 44 | 过滤低质量样本 | -| [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | +| [ Deduplicator ]( #deduplicator ) | 9 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | 下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。 From 2a1a1a72bde67b4cf0f86ad9b83c7ffcb0835aa8 Mon Sep 17 00:00:00 2001 From: panxuchen Date: Wed, 20 Nov 2024 14:11:16 +0800 Subject: [PATCH 6/6] fix pre-comment --- data_juicer/ops/deduplicator/ray_basic_deduplicator.py | 2 -- data_juicer/ops/deduplicator/ray_document_deduplicator.py | 5 +---- data_juicer/ops/deduplicator/ray_image_deduplicator.py | 5 +---- data_juicer/ops/deduplicator/ray_video_deduplicator.py | 6 +----- 4 files changed, 3 insertions(+), 15 deletions(-) diff --git a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py index 3fb902386..58343e98d 100644 --- a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py @@ -1,5 +1,3 @@ -from pydantic import PositiveInt - from data_juicer.utils.constant import HashKeys from data_juicer.utils.lazy_loader import LazyLoader diff --git a/data_juicer/ops/deduplicator/ray_document_deduplicator.py b/data_juicer/ops/deduplicator/ray_document_deduplicator.py index 667f86e38..7720e8810 100644 --- a/data_juicer/ops/deduplicator/ray_document_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_document_deduplicator.py @@ -2,7 +2,6 @@ import string import regex as re -from pydantic import PositiveInt from ..base_op import OPERATORS from .ray_basic_deduplicator import RayBasicDeduplicator @@ -31,9 +30,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args. """ - super().__init__(redis_address=redis_address, - *args, - **kwargs) + super().__init__(redis_address=redis_address, *args, **kwargs) self.lowercase = lowercase self.remove_non_character_regex = re.compile( f'\s+|\d+|[{re.escape(string.punctuation)}]' # noqa: W605 diff --git a/data_juicer/ops/deduplicator/ray_image_deduplicator.py b/data_juicer/ops/deduplicator/ray_image_deduplicator.py index 7610dc30d..d2c85cc59 100644 --- a/data_juicer/ops/deduplicator/ray_image_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_image_deduplicator.py @@ -1,5 +1,4 @@ import numpy as np -from pydantic import PositiveInt from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.mm_utils import load_data_with_context, load_image @@ -46,9 +45,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ - super().__init__(redis_address=redis_address, - *args, - **kwargs) + super().__init__(redis_address=redis_address, *args, **kwargs) if method not in HASH_METHOD: raise ValueError(f'Keep strategy [{method}] is not supported. 
' f'Can only be one of {HASH_METHOD}.') diff --git a/data_juicer/ops/deduplicator/ray_video_deduplicator.py b/data_juicer/ops/deduplicator/ray_video_deduplicator.py index 342abf7a1..e3deb23b3 100644 --- a/data_juicer/ops/deduplicator/ray_video_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_video_deduplicator.py @@ -1,7 +1,5 @@ import hashlib -from pydantic import PositiveInt - from data_juicer.utils.mm_utils import (close_video, load_data_with_context, load_video) @@ -30,9 +28,7 @@ def __init__(self, :param args: extra args :param kwargs: extra args """ - super().__init__(redis_address=redis_address, - *args, - **kwargs) + super().__init__(redis_address=redis_address, *args, **kwargs) def calculate_hash(self, sample, context=False): if self.video_key not in sample or not sample[self.video_key]:
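
For reference, a minimal end-to-end sketch of using the new operator,
outside the patch itself. It assumes a local Ray cluster and a Redis
server reachable at redis://localhost:6379; the input file and output
directory names below are hypothetical:

import ray

from data_juicer.ops.deduplicator import RayRedisMinhashDeduplicator

ray.init()

# any JSONL file whose samples carry a 'text' field (hypothetical path)
dataset = ray.data.read_json('demo-input.jsonl')

op = RayRedisMinhashDeduplicator(
    tokenization='space',       # shingle on whitespace-separated tokens
    window_size=5,              # 5-gram shingles
    num_permutations=256,       # MinHash signature length
    jaccard_threshold=0.7,      # near-duplicate cutoff
    redis_address='redis://localhost:6379',
)

# run() assigns uids, computes banded MinHashes, unions band collisions
# in Redis, and keeps one representative per duplicate group
deduped = op.run(dataset)
deduped.write_json('deduped-output')

Note that when Redis and Ray share a host, Redis cannot stay on port
6379 (which the existing config comments note is also Ray's default
port); passing a full redis_address URL keeps that choice explicit.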