Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add minhash deduplicator based on RAY. #502

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
29 changes: 29 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,35 @@ process:
redis_port: 6380 # the port of redis instance, please note that the default port of redis is 6379 which is the same as default port for ray, so we need to modify the default redis config to use it in other port
lowercase: false # whether to convert text to lower case
ignore_non_character: false # whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations
- ray_redis_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm
redis_address: 'redis://localhost:6379' # the address of the redis instance
tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece]
window_size: 5 # window size of shingling
num_permutations: 256 # number of permutations in minhash computing
jaccard_threshold: 0.7 # the min jaccard similarity threshold in near-duplicate detection. When the jaccard similarity of two sample texts is >= this threshold, they are regarded as similar samples and this op will only keep one of them after deduplication
num_bands: null # number of bands in LSH. Default it's None, and it will be determined by an optimal params computation algorithm by minimize the weighted sum of probs of False Positives and False Negatives
num_rows_per_band: null # number of rows in each band in LSH. Default it's None, and it will be determined by an optimal params computation algorithm
lowercase: true # whether to convert text to lower case
ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing simhash.
tokenizer_model: null # path for the sentencepiece model, used for sentencepiece tokenization.
- ray_bts_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm
tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece]
window_size: 5 # window size of shingling
num_permutations: 256 # number of permutations in minhash computing
jaccard_threshold: 0.7 # the min jaccard similarity threshold in near-duplicate detection. When the jaccard similarity of two sample texts is >= this threshold, they are regarded as similar samples and this op will only keep one of them after deduplication
num_bands: null # number of bands in LSH. Default it's None, and it will be determined by an optimal params computation algorithm by minimize the weighted sum of probs of False Positives and False Negatives
num_rows_per_band: null # number of rows in each band in LSH. Default it's None, and it will be determined by an optimal params computation algorithm
lowercase: true # whether to convert text to lower case
ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing simhash.
tokenizer_model: null # path for the sentencepiece model, used for sentencepiece tokenization.
union_find_parallel_num: 'auto' # number of parallel workers for union-find algorithm. Default it's 'auto', and it will be determined by half of the number of CPUs.
union_threshold: 256 # threshold for minhash values group to perform union-find algorightm.
max_pending_edge_buffer_task: 20 # max number of pending edge buffer ray tasks.
num_edge_buffer_task_returns: 10 # number of edge buffer tasks for `ray.wait` to return.
max_pending_filter_tasks: 20 # max number of pending filter ray tasks.
num_filter_task_returns: 10 # number of filter tasks for `ray.wait` to return.
merge_batch_size: 1000 # batch size for BTS operations.
tmp_file_name: './outputs/ray-dedup-tmp/' # the temporary folder name for deduplication.

# Selector ops
- frequency_specified_field_selector: # selector to select samples based on the sorted frequency of specified field value
Expand Down
78 changes: 48 additions & 30 deletions data_juicer/core/ray_data.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,39 @@
import os
from functools import partial

import pyarrow as pa
from loguru import logger

from data_juicer import cuda_device_count
from data_juicer.core.data import DJDataset
from data_juicer.ops import Filter, Mapper
from data_juicer.ops import Deduplicator, Filter, Mapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.process_utils import calculate_np

rd = LazyLoader('rd', 'ray.data')


def is_valid_path(item, dataset_dir):
full_path = os.path.abspath(os.path.join(dataset_dir, item))
return os.path.exists(full_path)
def get_abs_path(path, dataset_dir):
full_path = os.path.abspath(os.path.join(dataset_dir, path))
if os.path.exists(full_path):
return full_path
else:
return path


def convert_to_absolute_paths(dict_with_paths, dataset_dir, path_keys):
def convert_to_absolute_paths(samples, dataset_dir, path_keys):
samples = samples.to_pydict()
for key in path_keys:
if key not in dict_with_paths:
continue
if isinstance(dict_with_paths[key], list):
dict_with_paths[key] = [
os.path.abspath(os.path.join(dataset_dir, item))
if isinstance(item, str) and is_valid_path(dataset_dir, item)
else item for item in dict_with_paths[key]
]
elif isinstance(dict_with_paths[key], str):
dict_with_paths[key] = os.path.abspath(
os.path.join(dataset_dir,
dict_with_paths[key])) if is_valid_path(
dict_with_paths[key],
dataset_dir) else dict_with_paths[key]
return dict_with_paths
for idx in range(len(samples[key])):
paths = samples[key][idx]
if isinstance(paths, str):
samples[key][idx] = get_abs_path(paths, dataset_dir)
elif isinstance(paths, list):
samples[key][idx] = [
get_abs_path(item, dataset_dir) for item in paths
]
return pa.Table.from_pydict(samples)


# TODO: check path for nestdataset
Expand All @@ -43,22 +42,26 @@ def set_dataset_to_absolute_path(dataset, dataset_path, cfg):
Set all the path in input data to absolute path.
Checks dataset_dir and project_dir for valid paths.
"""
if not (cfg.video_key in dataset.columns() or cfg.image_key
in dataset.columns() or cfg.audio_key in dataset.columns()):
return dataset
dataset_dir = os.path.dirname(dataset_path)
dataset = dataset.map(lambda item: convert_to_absolute_paths(
item, dataset_dir, [cfg.video_key, cfg.image_key, cfg.audio_key]))
logger.info(f"transfer {dataset.count()} sample's paths")
path_keys = []
columns = dataset.columns()
for key in [cfg.video_key, cfg.image_key, cfg.audio_key]:
if key in columns:
path_keys.append(key)
if len(path_keys) > 0:
dataset_dir = os.path.dirname(dataset_path)
dataset = dataset.map_batches(partial(convert_to_absolute_paths,
dataset_dir=dataset_dir,
path_keys=path_keys),
batch_format='pyarrow',
zero_copy_batch=True)
return dataset


def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset:
columns = dataset.columns()
if dataset_path:
dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
columns = dataset.columns()
if Fields.stats not in columns:
logger.info(f'columns {columns}')

def process_batch_arrow(table: pa.Table) -> pa.Table:
new_column_data = [{} for _ in range(len(table))]
Expand All @@ -77,6 +80,11 @@ def get_num_gpus(op, op_proc):
return 1.0 / proc_per_gpu


def filter_batch(batch, filter_func):
mask = pa.array(filter_func(batch.to_pydict()))
return batch.filter(mask)


class RayDataset(DJDataset):

def __init__(self,
Expand Down Expand Up @@ -122,7 +130,17 @@ def _run_single_op(self, op):
if op.stats_export_path is not None:
self.data.write_json(op.stats_export_path,
force_ascii=False)
self.data = self.data.filter(op.process)
if op.is_batched_op():
self.data = self.data.map_batches(partial(
filter_batch, filter_func=op.process),
batch_format='pyarrow',
batch_size=batch_size,
num_gpus=num_gpus,
zero_copy_batch=True)
else:
self.data = self.data.filter(op.process)
elif isinstance(op, Deduplicator):
self.data = op.run(self.data)
else:
logger.error(
'Ray executor only support Filter and Mapper OPs for now')
Expand Down
5 changes: 4 additions & 1 deletion data_juicer/ops/deduplicator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
from .ray_basic_deduplicator import RayBasicDeduplicator
from .ray_document_deduplicator import RayDocumentDeduplicator
from .ray_image_deduplicator import RayImageDeduplicator
from .ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator
from .ray_redis_minhash_deduplicator import RayRedisMinhashDeduplicator
from .ray_video_deduplicator import RayVideoDeduplicator
from .video_deduplicator import VideoDeduplicator

__all__ = [
'DocumentDeduplicator', 'DocumentMinhashDeduplicator',
'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator',
'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator',
'VideoDeduplicator'
'RayImageDeduplicator', 'RayRedisMinhashDeduplicator',
'RayBTSMinhashDeduplicator', 'VideoDeduplicator',
]
Loading