diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 90fc18875..df4bf91ad 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -643,6 +643,35 @@ process: redis_port: 6380 # the port of redis instance, please note that the default port of redis is 6379 which is the same as default port for ray, so we need to modify the default redis config to use it in other port lowercase: false # whether to convert text to lower case ignore_non_character: false # whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations + - ray_redis_minhash_deduplicator: # the document deduplicator that can run on multiple nodes using the MinHashLSH algorithm + redis_address: 'redis://localhost:6379' # the address of the redis instance + tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece] + window_size: 5 # window size of shingling + num_permutations: 256 # number of permutations in minhash computing + jaccard_threshold: 0.7 # the min jaccard similarity threshold in near-duplicate detection. When the jaccard similarity of two sample texts is >= this threshold, they are regarded as similar samples and this op will only keep one of them after deduplication + num_bands: null # number of bands in LSH. Default it's None, and it will be determined by an optimal params computation algorithm by minimizing the weighted sum of the probabilities of false positives and false negatives + num_rows_per_band: null # number of rows in each band in LSH. Default it's None, and it will be determined by an optimal params computation algorithm + lowercase: true # whether to convert text to lower case + ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing minhash. + tokenizer_model: null # path for the sentencepiece model, used for sentencepiece tokenization. + - ray_bts_minhash_deduplicator: # the document deduplicator that can run on multiple nodes using the MinHashLSH algorithm + tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece] + window_size: 5 # window size of shingling + num_permutations: 256 # number of permutations in minhash computing + jaccard_threshold: 0.7 # the min jaccard similarity threshold in near-duplicate detection. When the jaccard similarity of two sample texts is >= this threshold, they are regarded as similar samples and this op will only keep one of them after deduplication + num_bands: null # number of bands in LSH. Default it's None, and it will be determined by an optimal params computation algorithm by minimizing the weighted sum of the probabilities of false positives and false negatives + num_rows_per_band: null # number of rows in each band in LSH. Default it's None, and it will be determined by an optimal params computation algorithm + lowercase: true # whether to convert text to lower case + ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing minhash. + tokenizer_model: null # path for the sentencepiece model, used for sentencepiece tokenization. + union_find_parallel_num: 'auto' # number of parallel workers for union-find algorithm. Default it's 'auto', and it will be determined by half of the number of CPUs. + union_threshold: 256 # threshold for minhash values group to perform the union-find algorithm. + max_pending_edge_buffer_task: 20 # max number of pending edge buffer ray tasks. + num_edge_buffer_task_returns: 10 # number of edge buffer tasks for `ray.wait` to return.
+ max_pending_filter_tasks: 20 # max number of pending filter ray tasks. + num_filter_task_returns: 10 # number of filter tasks for `ray.wait` to return. + merge_batch_size: 1000 # batch size for BTS operations. + tmp_file_name: './outputs/ray-dedup-tmp/' # the temporary folder name for deduplication. # Selector ops - frequency_specified_field_selector: # selector to select samples based on the sorted frequency of specified field value diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 0c131561e..621e68cd9 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -1,11 +1,12 @@ import os +from functools import partial import pyarrow as pa from loguru import logger from data_juicer import cuda_device_count from data_juicer.core.data import DJDataset -from data_juicer.ops import Filter, Mapper +from data_juicer.ops import Deduplicator, Filter, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.process_utils import calculate_np @@ -13,28 +14,26 @@ rd = LazyLoader('rd', 'ray.data') -def is_valid_path(item, dataset_dir): - full_path = os.path.abspath(os.path.join(dataset_dir, item)) - return os.path.exists(full_path) +def get_abs_path(path, dataset_dir): + full_path = os.path.abspath(os.path.join(dataset_dir, path)) + if os.path.exists(full_path): + return full_path + else: + return path -def convert_to_absolute_paths(dict_with_paths, dataset_dir, path_keys): +def convert_to_absolute_paths(samples, dataset_dir, path_keys): + samples = samples.to_pydict() for key in path_keys: - if key not in dict_with_paths: - continue - if isinstance(dict_with_paths[key], list): - dict_with_paths[key] = [ - os.path.abspath(os.path.join(dataset_dir, item)) - if isinstance(item, str) and is_valid_path(dataset_dir, item) - else item for item in dict_with_paths[key] - ] - elif isinstance(dict_with_paths[key], str): - dict_with_paths[key] = os.path.abspath( - os.path.join(dataset_dir, - dict_with_paths[key])) if is_valid_path( - dict_with_paths[key], - dataset_dir) else dict_with_paths[key] - return dict_with_paths + for idx in range(len(samples[key])): + paths = samples[key][idx] + if isinstance(paths, str): + samples[key][idx] = get_abs_path(paths, dataset_dir) + elif isinstance(paths, list): + samples[key][idx] = [ + get_abs_path(item, dataset_dir) for item in paths + ] + return pa.Table.from_pydict(samples) # TODO: check path for nestdataset @@ -43,22 +42,26 @@ def set_dataset_to_absolute_path(dataset, dataset_path, cfg): Set all the path in input data to absolute path. Checks dataset_dir and project_dir for valid paths. 
""" - if not (cfg.video_key in dataset.columns() or cfg.image_key - in dataset.columns() or cfg.audio_key in dataset.columns()): - return dataset - dataset_dir = os.path.dirname(dataset_path) - dataset = dataset.map(lambda item: convert_to_absolute_paths( - item, dataset_dir, [cfg.video_key, cfg.image_key, cfg.audio_key])) - logger.info(f"transfer {dataset.count()} sample's paths") + path_keys = [] + columns = dataset.columns() + for key in [cfg.video_key, cfg.image_key, cfg.audio_key]: + if key in columns: + path_keys.append(key) + if len(path_keys) > 0: + dataset_dir = os.path.dirname(dataset_path) + dataset = dataset.map_batches(partial(convert_to_absolute_paths, + dataset_dir=dataset_dir, + path_keys=path_keys), + batch_format='pyarrow', + zero_copy_batch=True) return dataset def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset: + columns = dataset.columns() if dataset_path: dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg) - columns = dataset.columns() if Fields.stats not in columns: - logger.info(f'columns {columns}') def process_batch_arrow(table: pa.Table) -> pa.Table: new_column_data = [{} for _ in range(len(table))] @@ -77,6 +80,11 @@ def get_num_gpus(op, op_proc): return 1.0 / proc_per_gpu +def filter_batch(batch, filter_func): + mask = pa.array(filter_func(batch.to_pydict())) + return batch.filter(mask) + + class RayDataset(DJDataset): def __init__(self, @@ -122,7 +130,17 @@ def _run_single_op(self, op): if op.stats_export_path is not None: self.data.write_json(op.stats_export_path, force_ascii=False) - self.data = self.data.filter(op.process) + if op.is_batched_op(): + self.data = self.data.map_batches(partial( + filter_batch, filter_func=op.process), + batch_format='pyarrow', + batch_size=batch_size, + num_gpus=num_gpus, + zero_copy_batch=True) + else: + self.data = self.data.filter(op.process) + elif isinstance(op, Deduplicator): + self.data = op.run(self.data) else: logger.error( 'Ray executor only support Filter and Mapper OPs for now') diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index 56aec0e10..3e9f55f47 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -5,6 +5,8 @@ from .ray_basic_deduplicator import RayBasicDeduplicator from .ray_document_deduplicator import RayDocumentDeduplicator from .ray_image_deduplicator import RayImageDeduplicator +from .ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator +from .ray_redis_minhash_deduplicator import RayRedisMinhashDeduplicator from .ray_video_deduplicator import RayVideoDeduplicator from .video_deduplicator import VideoDeduplicator @@ -12,5 +14,6 @@ 'DocumentDeduplicator', 'DocumentMinhashDeduplicator', 'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator', 'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator', - 'VideoDeduplicator' + 'RayImageDeduplicator', 'RayRedisMinhashDeduplicator', + 'RayBTSMinhashDeduplicator', 'VideoDeduplicator', ] diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py new file mode 100644 index 000000000..ba87edda9 --- /dev/null +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -0,0 +1,619 @@ +import time +import uuid +from collections import defaultdict +from typing import Optional + +import os +import ray +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import regex +from loguru 
import logger +from pydantic import Field, PositiveInt +from typing_extensions import Annotated +from typing import List, Union + +from data_juicer.utils.constant import HashKeys, Fields +from data_juicer.utils.model_utils import prepare_sentencepiece_model + +from ..base_op import OPERATORS, Deduplicator +from ..common.helper_func import split_on_whitespace +from .document_minhash_deduplicator import (MAX_HASH, MERSENNE_PRIME, + optimal_param, sha1_hash32) + + +BATCH_SIZE = 1000 + + +@ray.remote +class IdGenerator: + def __init__(self, start_id = 0): + self.next_id = start_id + + @ray.method(num_returns=2) + def get_next_id(self, count): + current_id = self.next_id + self.next_id += count + return (current_id, self.next_id) + + +@ray.remote(scheduling_strategy="SPREAD") +class EdgeBuffer: + def __init__(self): + self.edge_dict = {} + + def clear(self): + self.edge_dict = {} + + def set_edges(self, edge_dict): + self.edge_dict = edge_dict + + def get_edges(self, key): + return self.edge_dict.pop(key, []) + + +@ray.remote(scheduling_strategy="SPREAD") +class BTSUnionFind: + """ + A distributed implementation of Union-Find with load balancing. + + The original paper on BTS Union-Find is available at: + https://ieeexplore.ieee.org/document/10598116 + """ + def __init__( + self, + union_threshold, + parallel_num, + parallel_id, + remote_edge_buffers, + max_pending_edge_buffer_task, + num_edge_buffer_task_returns, + ): + self.union_threshold = union_threshold + self.parallel_num = parallel_num + self.parallel_id = parallel_id + self.hash_table = {} + self.parent = {} + self.old_parent = {} + self.remote_edge_buffers = remote_edge_buffers + self.edge_buffer = [] + self.edge_list_dict = {} + self.max_pending_edge_buffer_task = max_pending_edge_buffer_task + self.num_edge_buffer_task_returns = num_edge_buffer_task_returns + + def add_key_value_pairs(self, pairs): + for key, value in pairs: + if key not in self.hash_table: + self.hash_table[key] = [] + self.hash_table[key].append(value) + if len(self.hash_table[key]) > self.union_threshold: + self.hash_table[key] = [self.union_list(self.hash_table[key])] + + def flush_key_value_pairs(self): + for value in self.hash_table.values(): + if len(value) > 1: + self.union_list(value) + del self.hash_table + + def balanced_union_find(self): + for x, y in self.edge_buffer: + self.union(x, y) + self.edge_buffer = [] + result_refs = [] + for remote_edge_buffer in self.remote_edge_buffers: + if len(result_refs) > self.max_pending_edge_buffer_task: + ready_refs, result_refs = ray.wait( + result_refs, + num_returns=self.num_edge_buffer_task_returns + ) + edge_list = ray.get(ready_refs) + for edges in edge_list: + for x, y in edges: + self.union(x, y) + del ready_refs + result_refs.append( + remote_edge_buffer.get_edges.remote(self.parallel_id) + ) + edge_list = ray.get(result_refs) + for edges in edge_list: + for x, y in edges: + self.union(x, y) + del edge_list, result_refs + self.rebalancing() + return self.old_parent != self.parent + + def distribute_edge(self, u, v): + hash_u = u // BATCH_SIZE % self.parallel_num + hash_v = v // BATCH_SIZE % self.parallel_num + if hash_u not in self.edge_list_dict: + self.edge_list_dict[hash_u] = [] + self.edge_list_dict[hash_u].append((u, v)) + if hash_u != hash_v: + if hash_v not in self.edge_list_dict: + self.edge_list_dict[hash_v] = [] + self.edge_list_dict[hash_v].append((u, v)) + + def set_edge_buffer(self): + if self.parallel_id in self.edge_list_dict: + self.edge_buffer = self.edge_list_dict[self.parallel_id] + del 
self.edge_list_dict[self.parallel_id] + else: + self.edge_buffer = [] + ray.get( + self.remote_edge_buffers[self.parallel_id].set_edges.remote( + self.edge_list_dict + ) + ) + self.edge_list_dict = {} + + def edge_redistribution(self): + self.flush_key_value_pairs() + self.rebalancing() + self.edge_list_dict = {} + for u, v in self.parent.items(): + self.distribute_edge(u, v) + self.parent = {} + self.set_edge_buffer() + + def communication(self): + self.edge_list_dict = {} + del_list = [] + for u, v in self.parent.items(): + hash_u = u // BATCH_SIZE % self.parallel_num + if self.parent[u] != self.old_parent.get(u, u) or \ + (hash_u != self.parallel_id and v not in self.parent): + self.distribute_edge(u, v) + if hash_u != self.parallel_id: + del_list.append(u) + self.old_parent = self.parent.copy() + for u in del_list: + del self.parent[u] + self.set_edge_buffer() + + def find(self, x): + if x not in self.parent: + return x + else: + self.parent[x] = self.find(self.parent[x]) + return self.parent[x] + + def union(self, x, y): + px = self.find(x) + py = self.find(y) + if px == py: + return + if px > py: + px, py = py, px + self.parent[py] = px + + def union_list(self, x_list): + px_list = [self.find(x) for x in x_list] + p = min(px_list) + for px in px_list: + if p != px: + self.parent[px] = p + return p + + def rebalancing(self): + new_px_dict = {} + for x in self.parent: + hash_x = x // BATCH_SIZE % self.parallel_num + px = self.find(x) + key = (px, hash_x) + if key not in new_px_dict: + new_px_dict[key] = x + else: + new_px_dict[key] = min(new_px_dict[key], x) + px_set = set(px for px, _ in new_px_dict) + for px in px_set: + hash_px = px // BATCH_SIZE % self.parallel_num + key = (px, hash_px) + if key not in new_px_dict: + new_px_dict[key] = px + else: + new_px_dict[key] = min(new_px_dict[key], px) + + for x in self.parent: + hash_x = x // BATCH_SIZE % self.parallel_num + px = self.find(x) + key = (px, hash_x) + if x == new_px_dict[key]: + continue + self.parent[x] = new_px_dict[key] + + def squeeze(self): + dup_keys = { + x + for x in self.parent + if x // BATCH_SIZE % self.parallel_num == self.parallel_id + } + self.parent = dup_keys + self.old_parent = {} + self.edge_buffer = [] + ray.get(self.remote_edge_buffers[self.parallel_id].clear.remote()) + + def dup_idx(self, queries): + return [ + idx + for uid, idx in queries + if uid in self.parent + ] + + +OP_NAME = 'ray_bts_minhash_deduplicator' + + +@OPERATORS.register_module(OP_NAME) +class RayBTSMinhashDeduplicator(Deduplicator): + """ + A MinHashLSH-based fuzzy deduplicator running on Ray. + Near-duplicate samples are grouped by a distributed BTS + union-find, and only one sample per group is kept.
+ """ + + # TODO: Set a more reasonable value + EMPTY_HASH_VALUE = 'EMPTY' + _batched_op = True + + def __init__( + self, + tokenization: str = 'space', + window_size: PositiveInt = 5, + lowercase: bool = True, + ignore_pattern: Optional[str] = None, + num_permutations: PositiveInt = 256, + jaccard_threshold: Annotated[float, Field(ge=0, le=1)] = 0.7, + num_bands: Optional[PositiveInt] = None, + num_rows_per_band: Optional[PositiveInt] = None, + tokenizer_model: Optional[str] = None, + union_find_parallel_num: Union[int, str] = 'auto', + union_threshold: Optional[int] = 256, + max_pending_edge_buffer_task: Optional[int] = 20, + num_edge_buffer_task_returns: Optional[int] = 10, + max_pending_filter_tasks: Optional[int] = 20, + num_filter_task_returns: Optional[int] = 10, + merge_batch_size: Optional[int] = 1000, + tmp_file_name: Optional[str] = './outputs/ray-dedup-tmp/', + *args, + **kwargs, + ): + """ + Initialization method. + + :param tokenization: tokenization method for sample texts. It + should be one of [space, punctuation, character, + sentencepiece]. For English-like languages, we recommend + to use 'space', for Chinese-like languages, we recommend + to use 'character', and for multiple languages, we recommend + to use 'sentencepiece'. If using 'sentencepiece', please + provided the model path in the 'tokenizer_model' field. + :param window_size: window size of shingling + :param lowercase: whether to convert text to lower case first + :param ignore_pattern: whether to ignore sub-strings with + specific pattern when computing minhash + :param num_permutations: number of permutations in minhash + computing + :param jaccard_threshold: the min jaccard similarity threshold + in near-duplicate detection. When the jaccard similarity of + two sample texts is >= this threshold, they are regarded as + similar samples and this op will only keep one of them after + deduplication + :param num_bands: number of bands in LSH. Default it's None, and + it will be determined by an optimal params computation + algorithm by minimize the weighted sum of probs of False + Positives and False Negatives + :param num_rows_per_band: number of rows in each band in LSH. + Default it's None, and it will be determined by an optimal + params computation algorithm + :param tokenizer_model: path for the sentencepiece model, used for + sentencepiece tokenization. + :param union_find_parallel_num: number of parallel workers for + union-find algorithm. Default it's 'auto', and it will be + determined by half of the number of CPUs. + :param union_threshold: threshold for minhash values group to + perform union-find algorightm. Default it's 256. + :param max_pending_edge_buffer_task: max number of pending edge buffer + ray tasks. Default it's 20. + :param num_edge_buffer_task_returns: number of edge buffer tasks for + `ray.wait` to return. Default it's 10. + :param max_pending_filter_tasks: max number of pending filter ray + tasks. Default it's 20. + :param num_filter_task_returns: number of filter tasks for `ray.wait` + to return. Default it's 10. + :param merge_batch_size: batch size for BTS operations. Default + it's 1000. + :param tmp_file_name: the temporary folder name for deduplication. 
+ """ + super().__init__(*args, **kwargs) + # about minhash computation + self.tokenization = tokenization + self.window_size = window_size + self.lowercase = lowercase + self.ignore_pattern = ignore_pattern + if self.ignore_pattern: + self.ignore_pattern = regex.compile(self.ignore_pattern) + + # check parameters + if self.ignore_pattern and self.tokenization == 'punctuation': + logger.warning('Be careful that tokenization with punctuations ' + 'won\'t work if the ignore pattern includes ' + 'punctuations.') + self.punctuation_pattern = regex.compile(r'\p{P}') + + if self.tokenization == 'sentencepiece': + if tokenizer_model is None: + raise ValueError("To use 'sentencepiece' tokenization, " + "'tokenizer_model' is required.") + self.tokenizer = prepare_sentencepiece_model(tokenizer_model) + else: + self.tokenizer = None + + if self.tokenization == 'character': + def tokenization_func(text): + return { + str.encode(text[i:i + self.window_size]) + for i in range(len(text) - self.window_size) + } + elif self.tokenization == 'punctuation': + def tokenization_func(text): + tokens = self.punctuation_pattern.split(text) + return { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'space': + def tokenization_func(text): + tokens = split_on_whitespace(text) + return { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'sentencepiece': + def tokenization_func(text): + tokens = self.tokenizer.encode(text, out_type=str) + return { + str.encode(''.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + else: + raise NotImplementedError( + f'Unimplemented tokenization method [{self.tokenization}]') + self.tokenization_func = tokenization_func + + # about deduplication + self.num_permutation = num_permutations + self.jaccard_threshold = jaccard_threshold + self.num_bands = num_bands + self.num_rows_per_band = num_rows_per_band + + # initialize deduplication parameters + # check number of bands and rows + if self.num_bands is None or self.num_rows_per_band is None: + self.num_bands, self.num_rows_per_band = optimal_param( + self.jaccard_threshold, + self.num_permutation, + ) + + # compute hash ranges and create hash tables + self.hash_ranges = [(i * self.num_rows_per_band, + (i + 1) * self.num_rows_per_band) + for i in range(self.num_bands)] + + # generate permutations + gen = np.random.RandomState(seed=42) + self.perm_a, self.perm_b = np.array( + [( + gen.randint(1, MERSENNE_PRIME, dtype=np.uint64), + gen.randint(0, MERSENNE_PRIME, dtype=np.uint64), + ) for _ in range(self.num_permutation)], + dtype=np.uint64, + ).T + + if union_find_parallel_num == 'auto': + union_find_parallel_num = int( + ray.cluster_resources().get('CPU') / 2 + ) + else: + union_find_parallel_num = int(union_find_parallel_num) + + self.max_pending_edge_buffer_task = max_pending_edge_buffer_task + self.num_edge_buffer_task_returns = num_edge_buffer_task_returns + self.max_pending_filter_tasks = max_pending_filter_tasks + self.num_filter_task_returns = num_filter_task_returns + self.merge_batch_size = min(merge_batch_size, union_find_parallel_num) + + logger.info(f'union_find_parallel_num = {union_find_parallel_num}') + self.union_find_parallel_num = union_find_parallel_num + self.union_threshold = union_threshold + self.remote_edge_buffers = [ + EdgeBuffer.remote() + for _ in range(self.union_find_parallel_num) + ] + 
self.union_find_list = [ + BTSUnionFind.remote( + self.union_threshold, + self.union_find_parallel_num, + i, + self.remote_edge_buffers, + self.max_pending_edge_buffer_task, + self.num_edge_buffer_task_returns, + ) + for i in range(self.union_find_parallel_num) + ] + + self.tmp_file_name = os.path.join( + os.getcwd(), tmp_file_name, str(uuid.uuid4()) + ) + os.makedirs(self.tmp_file_name) + + empty_hash_value = np.full( + (self.num_rows_per_band,), + MAX_HASH, + dtype=np.uint32 + ) + self.empty_hash_value = b'\x00\x00\x00\x00' \ + + empty_hash_value.tobytes() + self.empty_hash_table_id = int( + MAX_HASH % self.union_find_parallel_num + ) + + def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: + pairs = {} + + for text, uid in zip(text_list, uid_list): + text = text.as_py() + if self.lowercase: + text = text.lower() + if self.ignore_pattern: + text = self.ignore_pattern.sub('', text) + + tokens = self.tokenization_func(text) + + if len(tokens) > 0: + hv = np.array( + [sha1_hash32(token) for token in tokens], + dtype=np.uint64 + ) + phv = ( + (hv[:, None] * self.perm_a[None, :] + + self.perm_b) % MERSENNE_PRIME + ).astype(np.uint32) + hash_values = phv.min(axis=0) + for i, (start, end) in enumerate(self.hash_ranges): + hash_value = i.to_bytes(4, 'big') \ + + hash_values[start:end].tobytes() + hash_table_id = hash_values[start] \ + % self.union_find_parallel_num + if hash_table_id not in pairs: + pairs[hash_table_id] = [] + pairs[hash_table_id].append((hash_value, uid)) + else: + if self.empty_hash_table_id not in pairs: + pairs[self.empty_hash_table_id] = [] + pairs[self.empty_hash_table_id].append( + (self.empty_hash_value, uid) + ) + result_refs = [] + for i, p in pairs.items(): + if len(result_refs) > self.max_pending_filter_tasks: + ready_refs, result_refs = ray.wait( + result_refs, + num_returns=self.num_filter_task_returns + ) + ray.get(ready_refs) + result_refs.append( + self.union_find_list[i].add_key_value_pairs.remote(p) + ) + ray.get(result_refs) + + def merge_op_batch(self, object_refs): + results = [] + while object_refs: + ready_refs, object_refs = ray.wait( + object_refs, + num_returns=min(self.merge_batch_size, len(object_refs)) + ) + results.extend(ray.get(ready_refs)) + return results + + def merge(self): + self.merge_op_batch([ + union_find.edge_redistribution.remote() + for union_find in self.union_find_list + ]) + while any( + self.merge_op_batch([ + union_find.balanced_union_find.remote() + for union_find in self.union_find_list + ]) + ): + self.merge_op_batch([ + union_find.communication.remote() + for union_find in self.union_find_list + ]) + self.merge_op_batch([ + union_find.squeeze.remote() + for union_find in self.union_find_list + ]) + + def filter_with_union_find(self, samples: pa.Table) -> pa.Table: + query_dict = {} + for idx, uid in enumerate(samples[HashKeys.uid]): + uid = uid.as_py() + hash_id = uid // BATCH_SIZE % self.union_find_parallel_num + if hash_id not in query_dict: + query_dict[hash_id] = [] + query_dict[hash_id].append((uid, idx)) + mask = np.ones(len(samples), dtype=np.bool_) + result_refs = [] + for hash_id, query in query_dict.items(): + if len(result_refs) > self.max_pending_filter_tasks: + ready_refs, result_refs = ray.wait( + result_refs, + num_returns=self.num_filter_task_returns + ) + results = ray.get(ready_refs) + for result in results: + mask[result] = False + del ready_refs + result_refs.append( + self.union_find_list[hash_id].dup_idx.remote(query) + ) + results = ray.get(result_refs) + for result in results: + 
mask[result] = False + del query_dict, results + columns_to_keep = [ + name + for name in samples.column_names + if name != HashKeys.uid + ] + return samples.select(columns_to_keep).filter(mask) + + def run(self, dataset): + start_time = time.time() + id_generator = IdGenerator.remote() + def minhash_with_uid(table: pa.Table) -> pa.Table: + num_rows = len(table) + min_id, max_id = ray.get( + id_generator.get_next_id.remote(num_rows) + ) + uid_list = range(min_id, max_id) + self.calc_minhash(table[self.text_key], uid_list) + new_table = table.append_column( + HashKeys.uid, + pa.array(list(uid_list)) + ) + if not new_table[Fields.stats][0].as_py(): + columns_to_keep = [ + name + for name in new_table.column_names + if name != Fields.stats + ] + new_table = new_table.select(columns_to_keep) + pq.write_table( + new_table, + os.path.join(self.tmp_file_name, f'{min_id}.parquet') + ) + return pa.Table.from_arrays([]) + + dataset.map_batches( + minhash_with_uid, + batch_format='pyarrow', + zero_copy_batch=True, + ).materialize() + dataset = ray.data.read_parquet(self.tmp_file_name) + end_time = time.time() + print(f'MinHash time = {end_time - start_time}') + + start_time = time.time() + self.merge() + end_time = time.time() + print(f'merge time = {end_time - start_time}') + result = dataset.map_batches( + self.filter_with_union_find, + batch_format='pyarrow', + zero_copy_batch=True, + ) + return result diff --git a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py new file mode 100644 index 000000000..ee5478e3b --- /dev/null +++ b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py @@ -0,0 +1,380 @@ +import random +import time +import uuid +from collections import defaultdict +from typing import Optional + +import numpy as np +import pandas as pd +import pyarrow as pa +import regex +from loguru import logger +from pydantic import Field, PositiveInt +from typing_extensions import Annotated + +from data_juicer.utils.constant import HashKeys +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import prepare_sentencepiece_model + +from ..base_op import OPERATORS, Deduplicator +from ..common.helper_func import split_on_whitespace +from .document_minhash_deduplicator import (MAX_HASH, MERSENNE_PRIME, + optimal_param, sha1_hash32) + +redis = LazyLoader('redis', 'redis') + + +def retry_on_busy(func): + + def wrapper(*args, **kwargs): + max_retries = 10 + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except Exception as e: + if 'BUSY' in str(e) and attempt < max_retries - 1: + time.sleep(random.uniform(0.1, 0.3) * (2**attempt)) + else: + raise + + return wrapper + + +class RedisUnionFind: + + def __init__(self, + prefix: str, + redis_address: str = 'redis://localhost:6379'): + self.prefix = prefix + self.redis_address = redis_address + self.redis = redis.from_url(url=redis_address) + self.set_key = f'{prefix}_UF_SET' + self.rank_key = f'{prefix}_UF_RANK' + self.incur_id_key = f'{prefix}_UF_INCURID' + + # Lua scripts + self.union_script = self.redis.register_script(""" + local function find(x) + local path = {} + while true do + local parent = redis.call('HGET', KEYS[1], x) + if not parent then + return nil + end + if parent == x then + break + end + table.insert(path, x) + x = parent + end + for _, node in ipairs(path) do + redis.call('HSET', KEYS[1], node, x) + end + return x + end + + local root_x = find(ARGV[1]) + local root_y = find(ARGV[2]) + if 
not root_x then + redis.call('HSET', KEYS[1], ARGV[1], ARGV[1]) + redis.call('HSET', KEYS[2], ARGV[1], 0) + root_x = ARGV[1] + end + if not root_y then + redis.call('HSET', KEYS[1], ARGV[2], ARGV[2]) + redis.call('HSET', KEYS[2], ARGV[2], 0) + root_y = ARGV[2] + end + if root_x == root_y then + return root_x + end + local rank_x = tonumber(redis.call('HGET', KEYS[2], root_x)) + local rank_y = tonumber(redis.call('HGET', KEYS[2], root_y)) + if rank_x < rank_y then + redis.call('HSET', KEYS[1], root_x, root_y) + return root_y + elseif rank_x > rank_y then + redis.call('HSET', KEYS[1], root_y, root_x) + return root_x + else + redis.call('HSET', KEYS[1], root_y, root_x) + redis.call('HINCRBY', KEYS[2], root_x, 1) + return root_x + end + """) + + def get_uid(self): + return int(self.redis.incr(self.incur_id_key)) + + @retry_on_busy + def union(self, x, y): + return self.union_script(keys=[self.set_key, self.rank_key], + args=[x, y]) + + def is_ancestor(self, x): + ancestor = self.redis.hget(self.set_key, x) + return ancestor is None or int(ancestor) == x + + def __reduce__(self): + return (RedisUnionFind, (self.prefix, self.redis_address)) + + def clean(self): + self.redis.delete(self.set_key, self.rank_key, self.incur_id_key) + + +OP_NAME = 'ray_redis_minhash_deduplicator' + + +@OPERATORS.register_module(OP_NAME) +class RayRedisMinhashDeduplicator(Deduplicator): + """ + A MinHashLSH-based fuzzy deduplicator running on Ray. + Near-duplicate samples are grouped by a union-find structure + stored in Redis, and only one sample per group is kept. + """ + + def __init__( + self, + tokenization: str = 'space', + window_size: PositiveInt = 5, + lowercase: bool = True, + ignore_pattern: Optional[str] = None, + num_permutations: PositiveInt = 256, + jaccard_threshold: Annotated[float, Field(ge=0, le=1)] = 0.7, + num_bands: Optional[PositiveInt] = None, + num_rows_per_band: Optional[PositiveInt] = None, + tokenizer_model: Optional[str] = None, + redis_address: str = 'redis://localhost:6380', + *args, + **kwargs, + ): + """ + Initialization method. + + :param tokenization: tokenization method for sample texts. It + should be one of [space, punctuation, character, + sentencepiece]. For English-like languages, we recommend + to use 'space', for Chinese-like languages, we recommend + to use 'character', and for multiple languages, we recommend + to use 'sentencepiece'. If using 'sentencepiece', please + provide the model path in the 'tokenizer_model' field. + :param window_size: window size of shingling + :param lowercase: whether to convert text to lower case first + :param ignore_pattern: whether to ignore sub-strings with + specific pattern when computing minhash + :param num_permutations: number of permutations in minhash + computing + :param jaccard_threshold: the min jaccard similarity threshold + in near-duplicate detection. When the jaccard similarity of + two sample texts is >= this threshold, they are regarded as + similar samples and this op will only keep one of them after + deduplication + :param num_bands: number of bands in LSH. Default it's None, and + it will be determined by an optimal params computation + algorithm by minimizing the weighted sum of the probabilities + of false positives and false negatives + :param num_rows_per_band: number of rows in each band in LSH. + Default it's None, and it will be determined by an optimal + params computation algorithm + :param tokenizer_model: path for the sentencepiece model, used for + sentencepiece tokenization. + :param redis_address: address of your redis instance, e.g.
+ 'redis://localhost:6379' + """ + super().__init__(*args, **kwargs) + # about minhash computation + self.tokenization = tokenization + self.window_size = window_size + self.lowercase = lowercase + self.ignore_pattern = ignore_pattern + if self.ignore_pattern: + self.ignore_pattern = regex.compile(self.ignore_pattern) + + # check parameters + if self.ignore_pattern and self.tokenization == 'punctuation': + logger.warning('Be careful that tokenization with punctuations ' + 'won\'t work if the ignore pattern includes ' + 'punctuations.') + self.punctuation_pattern = regex.compile(r'\p{P}') + + if self.tokenization == 'sentencepiece': + if tokenizer_model is None: + raise ValueError("To use 'sentencepiece' tokenization, " + "'tokenizer_model' is required.") + self.tokenizer = prepare_sentencepiece_model(tokenizer_model) + else: + self.tokenizer = None + + # about deduplication + self.num_permutation = num_permutations + self.jaccard_threshold = jaccard_threshold + self.num_bands = num_bands + self.num_rows_per_band = num_rows_per_band + + # initialize deduplication parameters + # check number of bands and rows + if self.num_bands is None or self.num_rows_per_band is None: + self.num_bands, self.num_rows_per_band = optimal_param( + self.jaccard_threshold, + self.num_permutation, + ) + + # compute hash ranges and create hash tables + self.hash_ranges = [(i * self.num_rows_per_band, + (i + 1) * self.num_rows_per_band) + for i in range(self.num_bands)] + self.hash_tables = [defaultdict(set) for _ in range(self.num_bands)] + + # generate permutations + gen = np.random.RandomState(seed=42) + self.perm_a, self.perm_b = np.array( + [( + gen.randint(1, MERSENNE_PRIME, dtype=np.uint64), + gen.randint(0, MERSENNE_PRIME, dtype=np.uint64), + ) for _ in range(self.num_permutation)], + dtype=np.uint64, + ).T + self.redis_address = redis_address + + def run(self, dataset): + from ray.data.aggregate import AggregateFn + + union_find = RedisUnionFind(prefix=uuid.uuid4().hex[:8], + redis_address=self.redis_address) + + def add_uid_column(table: pa.Table) -> pa.Table: + new_column_data = [union_find.get_uid() for _ in range(len(table))] + new_table = table.append_column(HashKeys.uid, [new_column_data]) + return new_table + + def calculate_minhash(table: pa.Table) -> pa.Table: + ids = table.column(HashKeys.uid).to_pandas() + texts = table.column(self.text_key).to_pandas() + hashes = texts.apply(lambda x: self.compute_minhash(x)) + hashes = pa.Array.from_pandas(hashes).flatten() + + repeated_ids = pa.Array.from_pandas(ids.repeat(self.num_bands)) + + return pa.Table.from_arrays([repeated_ids, hashes], + names=[HashKeys.uid, HashKeys.minhash]) + + def _is_null(r): + return pd.isnull(r) + + class UnionFn(AggregateFn): + + def __init__(self, union_find): + union_find = union_find + + def accumulate(cur, row): + if _is_null(row): + return cur + elif _is_null(cur): + return row[HashKeys.uid] + else: + root = union_find.union(row[HashKeys.uid], cur) + return int(root) + + def merge(a, b): + if _is_null(a): + return b + if _is_null(b): + return a + root = union_find.union(a, b) + return int(root) + + super().__init__( + init=lambda k: None, + accumulate_row=accumulate, + merge=merge, + name='union', + ) + + def filter_with_union_find(table: pa.Table) -> pa.Table: + uids = table.column(HashKeys.uid).to_pandas() + mask = pa.Array.from_pandas( + uids.apply(lambda x: union_find.is_ancestor(x))) + return table.filter(mask) + + dataset_with_id = dataset.map_batches( + add_uid_column, batch_format='pyarrow').materialize() + 
dataset_with_id.map_batches(calculate_minhash, + batch_format='pyarrow').groupby( + HashKeys.minhash).aggregate( + UnionFn(union_find)).materialize() + result = dataset_with_id.map_batches(filter_with_union_find, + batch_format='pyarrow').materialize() + logger.info(f'Keep {result.count()} samples after MinHash dedup.') + union_find.clean() + return result + + def compute_minhash(self, text): + """ + Compute minhash values for the sample. + + :param sample: input sample + :return: sample with minhash value. + """ + if self.lowercase: + text = text.lower() + if self.ignore_pattern: + text = self.ignore_pattern.sub('', text) + + # get tokens for different tokenization method + tokens = set() + if self.tokenization == 'character': + tokens = { + str.encode(text[i:i + self.window_size]) + for i in range(len(text) - self.window_size) + } + elif self.tokenization == 'punctuation': + tokens = self.punctuation_pattern.split(text) + tokens = { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'space': + tokens = split_on_whitespace(text) + tokens = { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'sentencepiece': + tokens = self.tokenizer.encode(text, out_type=str) + tokens = { + str.encode(''.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + else: + raise NotImplementedError( + f'Unimplemented tokenization method [{self.tokenization}]') + + # # compute minhash value + # hv = np.array([sha1_hash32(token) for token in tokens], + # dtype=np.uint64) + # phv = np.bitwise_and( + # ((hv * np.tile(self.perm_a, + # (len(hv), 1)).T).T + self.perm_b) % MERSENNE_PRIME, + # MAX_HASH) + # hash_values = np.vstack([ + # phv, + # np.ones(self.num_permutation, dtype=np.uint64) * MAX_HASH + # ]).min(axis=0) + if len(tokens) > 0: + hv = np.array( + [sha1_hash32(token) for token in tokens], + dtype=np.uint64 + ) + phv = ( + (hv[:, None] * self.perm_a[None, :] + + self.perm_b) % MERSENNE_PRIME + ).astype(np.uint32) + hash_values = phv.min(axis=0) + else: + hash_values = np.full_like(self.perm_a, MAX_HASH, dtype=np.uint32) + return [ + bytes(hash_values[start:end].byteswap().data) + + start.to_bytes(4, byteorder='little') + for start, end in self.hash_ranges + # groupby minhash||brand_id + ] diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index ab88035b9..1fe8d7002 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -216,6 +216,7 @@ class StatsKeys(object, metaclass=StatsKeysMeta): class HashKeys(object): + uid = DEFAULT_PREFIX + 'uid' hash = DEFAULT_PREFIX + 'hash' minhash = DEFAULT_PREFIX + 'minhash' simhash = DEFAULT_PREFIX + 'simhash' diff --git a/docs/Operators.md b/docs/Operators.md index 7717ba434..0c8d708e6 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -173,6 +173,8 @@ All the specific operators are listed below, each featured with several capabili | document_simhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level using SimHash | [code](../data_juicer/ops/deduplicator/document_simhash_deduplicator.py) | 
[tests](../tests/ops/deduplicator/test_document_simhash_deduplicator.py) | | image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | Deduplicates samples at document-level using exact matching of images between documents | [code](../data_juicer/ops/deduplicator/image_deduplicator.py) | [tests](../tests/ops/deduplicator/test_image_deduplicator.py) | | video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Deduplicates samples at document-level using exact matching of videos between documents | [code](../data_juicer/ops/deduplicator/video_deduplicator.py) | [tests](../tests/ops/deduplicator/test_video_deduplicator.py) | +| ray_redis_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level using MinHashLSH based on Ray and Redis | [code](../data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py) | - | +| ray_bts_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level using MinHashLSH based on Ray | [code](../data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py) | - | | ray_document_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level by comparing MD5 hash on ray | [code](../data_juicer/ops/deduplicator/ray_document_deduplicator.py) | - | | ray_image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | Deduplicates samples at document-level using exact matching of images between documents on ray | [code](../data_juicer/ops/deduplicator/ray_image_deduplicator.py) | - | | ray_video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Deduplicates samples at document-level using exact matching of videos between documents on ray | [code](../data_juicer/ops/deduplicator/ray_video_deduplicator.py) | - | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 81aee2149..57b009238 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -172,6 +172,8 @@ Data-Juicer 中的算子分为以下 5 种类型。 | document_simhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 SimHash 在文档级别对样本去重 | [code](../data_juicer/ops/deduplicator/document_simhash_deduplicator.py) | [tests](../tests/ops/deduplicator/test_document_simhash_deduplicator.py) | | image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 使用文档之间图像的精确匹配在文档级别删除重复样本 | [code](../data_juicer/ops/deduplicator/image_deduplicator.py) | [tests](../tests/ops/deduplicator/test_image_deduplicator.py) | | video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 
使用文档之间视频的精确匹配在文档级别删除重复样本 | [code](../data_juicer/ops/deduplicator/video_deduplicator.py) | [tests](../tests/ops/deduplicator/test_video_deduplicator.py) | +| ray_redis_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 MinHashLSH 在文档级别对样本去重,面向 RAY 分布式模式(基于Redis) | [code](../data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py) | - | +| ray_bts_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 MinHashLSH 在文档级别对样本去重,面向 RAY 分布式模式 | [code](../data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py) | - | | ray_document_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较 MD5 哈希值在文档级别对样本去重,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_document_deduplicator.py) | - | | ray_image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 使用文档之间图像的精确匹配在文档级别删除重复样本,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_image_deduplicator.py) | - | | ray_video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 使用文档之间视频的精确匹配在文档级别删除重复样本,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_video_deduplicator.py) | - | diff --git a/environments/dist_requires.txt b/environments/dist_requires.txt index 4060a654f..b6ab28d06 100644 --- a/environments/dist_requires.txt +++ b/environments/dist_requires.txt @@ -1,2 +1,2 @@ -ray==2.31.0 +ray<=2.38.0 redis>=5.0.0
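A note on the default LSH parameters shared by both new operators: when `num_bands` and `num_rows_per_band` are left as null, they are derived by the `optimal_param` routine imported from `document_minhash_deduplicator.py`, which searches for a banding scheme that minimizes the weighted probabilities of false positives and false negatives around `jaccard_threshold`. The sketch below is not part of the patch; it is a simplified stand-in for that search under the standard MinHashLSH collision model, and the helper names `collision_probability` and `pick_bands_and_rows` are invented for this illustration.

    # Illustrative sketch only -- the operators themselves use optimal_param()
    # from document_minhash_deduplicator.py; this is a simplified equivalent.
    import numpy as np

    def collision_probability(s, num_bands, rows_per_band):
        # Standard MinHashLSH model: probability that two documents with
        # Jaccard similarity `s` collide in at least one band.
        return 1.0 - (1.0 - s ** rows_per_band) ** num_bands

    def pick_bands_and_rows(threshold=0.7, num_perm=256,
                            fp_weight=0.5, fn_weight=0.5):
        # Brute-force the (bands, rows) grid with bands * rows <= num_perm,
        # minimizing the weighted areas of false positives (similarity below
        # the threshold but colliding) and false negatives (above it, missed).
        s = np.linspace(0.0, 1.0, 1001)
        best, best_cost = None, float('inf')
        for bands in range(1, num_perm + 1):
            rows = num_perm // bands
            prob = collision_probability(s, bands, rows)
            fp = np.trapz(np.where(s < threshold, prob, 0.0), s)
            fn = np.trapz(np.where(s >= threshold, 1.0 - prob, 0.0), s)
            cost = fp_weight * fp + fn_weight * fn
            if cost < best_cost:
                best, best_cost = (bands, rows), cost
        return best

    if __name__ == '__main__':
        # With the defaults above (jaccard_threshold=0.7, num_permutations=256)
        # this prints a (num_bands, num_rows_per_band) pair; the exact values
        # may differ slightly from what optimal_param() returns.
        print(pick_bands_and_rows())

Increasing the rows per band makes a single-band collision stricter (fewer false positives), while adding bands gives more chances to collide (fewer false negatives); the weighted-area search balances the two around the configured threshold, which is why both operators can leave these two fields unset.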