Add minhash deduplicator based on RAY. #502

Open · wants to merge 23 commits into base: main
Changes from 1 commit
update docs and code format
chenyushuo committed Dec 16, 2024
commit 62caefe72dba2da6e47fde7a13637909ea06423f
18 changes: 18 additions & 0 deletions configs/config_all.yaml
@@ -654,6 +654,24 @@ process:
lowercase: true # whether to convert text to lower case
ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing simhash.
tokenizer_model: null # path for the sentencepiece model, used for sentencepiece tokenization.
- ray_bts_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm
tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece]
window_size: 5 # window size of shingling
num_permutations: 256 # number of permutations in minhash computing
jaccard_threshold: 0.7 # the min jaccard similarity threshold in near-duplicate detection. When the jaccard similarity of two sample texts is >= this threshold, they are regarded as similar samples and this op will only keep one of them after deduplication
num_bands: null # number of bands in LSH. Default it's None, and it will be determined by an optimal parameter computation that minimizes the weighted sum of the probabilities of false positives and false negatives
num_rows_per_band: null # number of rows in each band in LSH. Default it's None, and it will be determined by the same optimal parameter computation
lowercase: true # whether to convert text to lower case
ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing minhash.
tokenizer_model: null # path for the sentencepiece model, used for sentencepiece tokenization.
union_find_parallel_num: 'auto' # number of parallel workers for the union-find algorithm. Default it's 'auto', which sets it to half the number of available CPUs.
union_threshold: 256 # threshold for minhash value groups to perform the union-find algorithm.
max_pending_edge_buffer_task: 20 # max number of pending edge buffer ray tasks.
num_edge_buffer_task_returns: 10 # number of edge buffer tasks for `ray.wait` to return.
max_pending_filter_tasks: 20 # max number of pending filter ray tasks.
num_filter_task_returns: 10 # number of filter tasks for `ray.wait` to return.
merge_batch_size: 1000 # batch size for BTS operations.
tmp_file_name: './outputs/ray-dedup-tmp/' # the temporary folder name for deduplication.

# Selector ops
- frequency_specified_field_selector: # selector to select samples based on the sorted frequency of specified field value
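When num_bands and num_rows_per_band are left as null, the op derives them from jaccard_threshold and num_permutations by minimizing the weighted sum of false-positive and false-negative probabilities, as the comments above describe. Below is a minimal sketch of that kind of parameter search under the usual MinHashLSH collision model; the helper is illustrative and not the op's actual optimal_param.

```python
# Sketch of choosing num_bands / num_rows_per_band when both are null:
# enumerate (bands, rows) splits of num_permutations and keep the split that
# minimizes a weighted sum of false-positive and false-negative probabilities
# under P(candidate | similarity s) = 1 - (1 - s^rows)^bands.
# Illustrative only; not the op's actual optimal_param helper.

def _integrate(f, a, b, steps=200):
    """Midpoint-rule approximation of the integral of f over [a, b]."""
    width = (b - a) / steps
    return sum(f(a + (i + 0.5) * width) * width for i in range(steps))

def optimal_param_sketch(threshold, num_perm, fp_weight=0.5, fn_weight=0.5):
    best, best_error = (1, num_perm), float('inf')
    for bands in range(1, num_perm + 1):
        rows = num_perm // bands
        collide = lambda s: 1.0 - (1.0 - s ** rows) ** bands
        fp = _integrate(collide, 0.0, threshold)                      # collide although dissimilar
        fn = _integrate(lambda s: 1.0 - collide(s), threshold, 1.0)   # miss although similar
        error = fp_weight * fp + fn_weight * fn
        if error < best_error:
            best, best_error = (bands, rows), error
    return best

print(optimal_param_sketch(threshold=0.7, num_perm=256))  # -> (num_bands, num_rows_per_band)
```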
2 changes: 1 addition & 1 deletion data_juicer/core/ray_executor.py
@@ -56,7 +56,7 @@ def run(self, load_data_np=None):
from data_juicer.format.formatter import FORMATTERS
dataset = FORMATTERS.modules[obj_name](**args).load_dataset()
else:
dataset = rd.read_json(self.cfg.dataset_path, ray_remote_args=dict(scheduling_strategy="SPREAD"))
dataset = rd.read_json(self.cfg.dataset_path)

# convert all the path in dataset to absolute path
dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg)
5 changes: 1 addition & 4 deletions data_juicer/ops/deduplicator/__init__.py
@@ -5,10 +5,8 @@
from .ray_basic_deduplicator import RayBasicDeduplicator
from .ray_document_deduplicator import RayDocumentDeduplicator
from .ray_image_deduplicator import RayImageDeduplicator
from .ray_minhash_deduplicator import RayMinhashDeduplicator
from .ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator
from .ray_redis_minhash_deduplicator import RayRedisMinhashDeduplicator
from .ray_multi_redis_minhash_deduplicator import RayMultiRedisMinhashDeduplicator
from .ray_video_deduplicator import RayVideoDeduplicator
from .video_deduplicator import VideoDeduplicator

@@ -17,6 +15,5 @@
'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator',
'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator',
'RayImageDeduplicator', 'RayRedisMinhashDeduplicator',
'RayMinhashDeduplicator', 'RayBTSMinhashDeduplicator',
'RayMultiRedisMinhashDeduplicator', 'VideoDeduplicator',
'RayBTSMinhashDeduplicator', 'VideoDeduplicator',
]
107 changes: 75 additions & 32 deletions data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
@@ -22,10 +22,13 @@
optimal_param, sha1_hash32)


BATCH_SIZE = 1000


@ray.remote
class IdGenerator:
def __init__(self):
self.next_id = 0
def __init__(self, start_id = 0):
self.next_id = start_id

@ray.method(num_returns=2)
def get_next_id(self, count):
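The IdGenerator actor above hands out contiguous, non-overlapping uid ranges to concurrent map tasks; num_returns=2 lets a caller receive the range boundaries as two object refs. The body of get_next_id is cut off by the hunk, so the increment in this sketch is an assumption, and the class name is illustrative.

```python
import ray

@ray.remote
class IdGeneratorSketch:
    """Hands out contiguous id ranges; a hedged stand-in for the op's IdGenerator."""

    def __init__(self, start_id=0):
        self.next_id = start_id

    @ray.method(num_returns=2)
    def get_next_id(self, count):
        # Assumed body: reserve `count` ids and return the [min_id, max_id) range.
        min_id = self.next_id
        self.next_id += count
        return min_id, self.next_id

ray.init(ignore_reinit_error=True)
id_gen = IdGeneratorSketch.remote()
min_ref, max_ref = id_gen.get_next_id.remote(1000)  # two refs, thanks to num_returns=2
min_id, max_id = ray.get([min_ref, max_ref])
print(min_id, max_id)  # 0 1000
```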
@@ -52,14 +55,14 @@ def get_edges(self, key):
@ray.remote(scheduling_strategy="SPREAD")
class BTSUnionFind:
def __init__(
self,
union_threshold,
parallel_num,
parallel_id,
remote_edge_buffers,
max_pending_edge_buffer_task,
num_edge_buffer_task_returns,
):
self,
union_threshold,
parallel_num,
parallel_id,
remote_edge_buffers,
max_pending_edge_buffer_task,
num_edge_buffer_task_returns,
):
self.union_threshold = union_threshold
self.parallel_num = parallel_num
self.parallel_id = parallel_id
@@ -114,8 +117,8 @@ def balanced_union_find(self):
return self.old_parent != self.parent

def distribute_edge(self, u, v):
hash_u = u % self.parallel_num
hash_v = v % self.parallel_num
hash_u = u // BATCH_SIZE % self.parallel_num
hash_v = v // BATCH_SIZE % self.parallel_num
if hash_u not in self.edge_list_dict:
self.edge_list_dict[hash_u] = []
self.edge_list_dict[hash_u].append((u, v))
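This hunk changes the shard key from u % parallel_num to u // BATCH_SIZE % parallel_num, so a whole batch of BATCH_SIZE consecutive uids is owned by one union-find worker instead of being scattered round-robin. A small illustration with made-up numbers:

```python
# Illustration of the new shard key: consecutive uids within one batch of
# BATCH_SIZE all map to the same union-find worker, while plain modulo would
# scatter them across workers.
BATCH_SIZE = 1000
parallel_num = 4

uids = [0, 1, 999, 1000, 1001, 2500, 4000]
old_shards = [u % parallel_num for u in uids]
new_shards = [u // BATCH_SIZE % parallel_num for u in uids]

print(old_shards)  # [0, 1, 3, 0, 1, 0, 0] -- neighbouring uids split across workers
print(new_shards)  # [0, 0, 0, 1, 1, 2, 0] -- each batch stays on one worker
```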
@@ -130,7 +133,11 @@ def set_edge_buffer(self):
del self.edge_list_dict[self.parallel_id]
else:
self.edge_buffer = []
ray.get(self.remote_edge_buffers[self.parallel_id].set_edges.remote(self.edge_list_dict))
ray.get(
self.remote_edge_buffers[self.parallel_id].set_edges.remote(
self.edge_list_dict
)
)
self.edge_list_dict = {}

def edge_redistribution(self):
@@ -146,8 +153,9 @@ def communication(self):
self.edge_list_dict = {}
del_list = []
for u, v in self.parent.items():
hash_u = u % self.parallel_num
if self.parent[u] != self.old_parent.get(u, u) or (hash_u != self.parallel_id and v not in self.parent):
hash_u = u // BATCH_SIZE % self.parallel_num
if self.parent[u] != self.old_parent.get(u, u) or \
(hash_u != self.parallel_id and v not in self.parent):
self.distribute_edge(u, v)
if hash_u != self.parallel_id:
del_list.append(u)
@@ -183,7 +191,7 @@ def union_list(self, x_list):
def rebalancing(self):
new_px_dict = {}
for x in self.parent:
hash_x = x % self.parallel_num
hash_x = x // BATCH_SIZE % self.parallel_num
px = self.find(x)
key = (px, hash_x)
if key not in new_px_dict:
@@ -192,15 +200,15 @@ def rebalancing(self):
new_px_dict[key] = min(new_px_dict[key], x)
px_set = set(px for px, _ in new_px_dict)
for px in px_set:
hash_px = px % self.parallel_num
hash_px = px // BATCH_SIZE % self.parallel_num
key = (px, hash_px)
if key not in new_px_dict:
new_px_dict[key] = px
else:
new_px_dict[key] = min(new_px_dict[key], px)

for x in self.parent:
hash_x = x % self.parallel_num
hash_x = x // BATCH_SIZE % self.parallel_num
px = self.find(x)
key = (px, hash_x)
if x == new_px_dict[key]:
@@ -211,7 +219,7 @@ def squeeze(self):
dup_keys = {
x
for x in self.parent
if x % self.parallel_num == self.parallel_id
if x // BATCH_SIZE % self.parallel_num == self.parallel_id
}
self.parent = dup_keys
self.old_parent = {}
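balanced_union_find, rebalancing and squeeze above all operate on a sharded parent dict, i.e. a distributed union-find. For reference, here is a minimal single-node sketch of the underlying find/union with path compression and the min-id-as-root convention; illustrative only, since the real BTSUnionFind keeps just its locally owned keys after squeeze.

```python
# Minimal single-node union-find with path compression, the primitive that the
# distributed BTSUnionFind shards and synchronizes across workers.
class UnionFindSketch:
    def __init__(self):
        self.parent = {}  # x -> parent; a missing key means x is its own root

    def find(self, x):
        root = x
        while self.parent.get(root, root) != root:
            root = self.parent[root]
        # path compression: point every node on the path straight at the root
        while self.parent.get(x, x) != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, x, y):
        px, py = self.find(x), self.find(y)
        if px != py:
            # keep the smaller id as the root, matching the min-representative convention
            if px > py:
                px, py = py, px
            self.parent[py] = px

uf = UnionFindSketch()
uf.union(3, 7)
uf.union(7, 42)
print(uf.find(42))  # 3
```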
@@ -293,6 +301,22 @@ def __init__(
params computation algorithm
:param tokenizer_model: path for the sentencepiece model, used for
sentencepiece tokenization.
:param union_find_parallel_num: number of parallel workers for
union-find algorithm. Default it's 'auto', and it will be
determined by half of the number of CPUs.
:param union_threshold: threshold for minhash value groups to
perform the union-find algorithm. Default it's 256.
:param max_pending_edge_buffer_task: max number of pending edge buffer
ray tasks. Default it's 20.
:param num_edge_buffer_task_returns: number of edge buffer tasks for
`ray.wait` to return. Default it's 10.
:param max_pending_filter_tasks: max number of pending filter ray
tasks. Default it's 20.
:param num_filter_task_returns: number of filter tasks for `ray.wait`
to return. Default it's 10.
:param merge_batch_size: batch size for BTS operations. Default
it's 1000.
:param tmp_file_name: the temporary folder name for deduplication.
"""
super().__init__(*args, **kwargs)
# about minhash computation
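The max_pending_*_tasks and num_*_task_returns parameters documented above drive a simple backpressure loop: cap the number of in-flight Ray tasks and drain a few with ray.wait before submitting more, which is the pattern used later in calc_minhash. A standalone sketch with an illustrative task:

```python
import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def illustrative_task(i):
    return i * i

MAX_PENDING = 20   # cf. max_pending_filter_tasks / max_pending_edge_buffer_task
NUM_RETURNS = 10   # cf. num_filter_task_returns / num_edge_buffer_task_returns

result_refs = []
for i in range(100):
    if len(result_refs) > MAX_PENDING:
        # Block until NUM_RETURNS tasks finish, then release their results
        ready_refs, result_refs = ray.wait(result_refs, num_returns=NUM_RETURNS)
        ray.get(ready_refs)
    result_refs.append(illustrative_task.remote(i))
ray.get(result_refs)  # drain the tail
```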
@@ -380,7 +404,9 @@ def tokenization_func(text):
).T

if union_find_parallel_num == 'auto':
union_find_parallel_num = int(ray.cluster_resources().get('CPU', 32)) // 2
union_find_parallel_num = int(
ray.cluster_resources().get('CPU') / 2
)
else:
union_find_parallel_num = int(union_find_parallel_num)

@@ -409,11 +435,20 @@ def tokenization_func(text):
for i in range(self.union_find_parallel_num)
]

self.tmp_file_name = os.path.join(os.getcwd(), tmp_file_name, str(uuid.uuid4()))
self.tmp_file_name = os.path.join(
os.getcwd(), tmp_file_name, str(uuid.uuid4())
)

empty_hash_value = np.full((self.num_rows_per_band,), MAX_HASH, dtype=np.uint32)
self.empty_hash_value = b'\x00\x00\x00\x00' + empty_hash_value.tobytes()
self.empty_hash_table_id = int(MAX_HASH % self.union_find_parallel_num)
empty_hash_value = np.full(
(self.num_rows_per_band,),
MAX_HASH,
dtype=np.uint32
)
self.empty_hash_value = b'\x00\x00\x00\x00' \
+ empty_hash_value.tobytes()
self.empty_hash_table_id = int(
MAX_HASH % self.union_find_parallel_num
)

def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table:
pairs = {}
@@ -438,15 +473,19 @@ def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table:
).astype(np.uint32)
hash_values = phv.min(axis=0)
for i, (start, end) in enumerate(self.hash_ranges):
hash_value = i.to_bytes(4, 'big') + hash_values[start:end].tobytes()
hash_table_id = hash_values[start] % self.union_find_parallel_num
hash_value = i.to_bytes(4, 'big') \
+ hash_values[start:end].tobytes()
hash_table_id = hash_values[start] \
% self.union_find_parallel_num
if hash_table_id not in pairs:
pairs[hash_table_id] = []
pairs[hash_table_id].append((hash_value, uid))
else:
if self.empty_hash_table_id not in pairs:
pairs[self.empty_hash_table_id] = []
pairs[self.empty_hash_table_id].append((self.empty_hash_value, uid))
pairs[self.empty_hash_table_id].append(
(self.empty_hash_value, uid)
)
result_refs = []
for i, p in pairs.items():
if len(result_refs) > self.max_pending_filter_tasks:
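calc_minhash builds one signature of num_permutations minimum hash values per text, slices it into LSH bands, prefixes each band with its index to form the grouping key, and routes the (hash_value, uid) pair to a union-find worker chosen from the band's first hash. Below is a stripped-down, self-contained sketch of the signature-and-banding step; the permutation scheme, prime size and helper name are assumptions (the real op hashes shingles with sha1_hash32).

```python
import numpy as np

MERSENNE_PRIME = (1 << 31) - 1  # sketch-sized prime; the real op uses a larger one

def minhash_bands_sketch(tokens, num_permutations=256, num_bands=32, seed=42):
    """Return {band_index: band_key_bytes} for one document -- a hedged sketch."""
    rows_per_band = num_permutations // num_bands
    gen = np.random.RandomState(seed)
    a = gen.randint(1, MERSENNE_PRIME, size=num_permutations).astype(np.uint64)
    b = gen.randint(0, MERSENNE_PRIME, size=num_permutations).astype(np.uint64)

    # 32-bit hashes of the shingles (the real op uses sha1_hash32 on n-grams)
    hashes = np.array([hash(t) & 0xFFFFFFFF for t in tokens], dtype=np.uint64)
    # permuted hash values: one row per shingle, one column per permutation
    phv = ((hashes[:, None] * a + b) % MERSENNE_PRIME).astype(np.uint32)
    signature = phv.min(axis=0)  # column-wise minimum = MinHash signature

    bands = {}
    for i in range(num_bands):
        start, end = i * rows_per_band, (i + 1) * rows_per_band
        # band-index prefix + raw band bytes, mirroring the hash_value built above
        bands[i] = i.to_bytes(4, 'big') + signature[start:end].tobytes()
    return bands

keys = minhash_bands_sketch('a quick brown fox jumps over the lazy dog'.split())
print(len(keys), len(next(iter(keys.values()))))  # 32 bands, 4 + 8*4 bytes each
```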
@@ -494,7 +533,7 @@ def filter_with_union_find(self, samples: pa.Table) -> pa.Table:
query_dict = {}
for idx, uid in enumerate(samples[HashKeys.uid]):
uid = uid.as_py()
hash_id = uid % self.union_find_parallel_num
hash_id = uid // BATCH_SIZE % self.union_find_parallel_num
if hash_id not in query_dict:
query_dict[hash_id] = []
query_dict[hash_id].append((uid, idx))
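filter_with_union_find groups a batch's uids by owning shard (again uid // BATCH_SIZE % parallel_num), asks each shard which uids are duplicates, and keeps only the surviving rows of the pyarrow batch. A minimal local sketch of the final mask-and-filter step, with a hard-coded dup_uids set standing in for the actors' answers:

```python
import pyarrow as pa

# Stand-in for the answers collected from the union-find actors: uids flagged
# as duplicates that should be dropped from this batch.
dup_uids = {1001, 1003}

samples = pa.table({
    'text': ['a', 'b', 'c', 'd'],
    'uid': [1000, 1001, 1002, 1003],
})

# Build a keep-mask aligned with the batch rows, then filter the table.
mask = [uid.as_py() not in dup_uids for uid in samples['uid']]
kept = samples.filter(pa.array(mask))
print(kept.to_pydict())  # {'text': ['a', 'c'], 'uid': [1000, 1002]}
```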
@@ -529,10 +568,15 @@ def run(self, dataset):
id_generator = IdGenerator.remote()
def minhash_with_uid(table: pa.Table) -> pa.Table:
num_rows = len(table)
min_id, max_id = ray.get(id_generator.get_next_id.remote(num_rows))
min_id, max_id = ray.get(
id_generator.get_next_id.remote(num_rows)
)
uid_list = range(min_id, max_id)
self.calc_minhash(table[self.text_key], uid_list)
new_table = table.append_column(HashKeys.uid, pa.array(list(uid_list)))
new_table = table.append_column(
HashKeys.uid,
pa.array(list(uid_list))
)
return new_table

dataset.map_batches(
Expand All @@ -543,7 +587,7 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
self.tmp_file_name,
force_ascii=False
) # TODO: balance file size
dataset = ray.data.read_parquet(self.tmp_file_name, ray_remote_args=dict(scheduling_strategy="SPREAD"))
dataset = ray.data.read_parquet(self.tmp_file_name)
end_time = time.time()
print(f'MinHash time = {end_time - start_time}')
Collaborator review comment:
-> logger.info
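Following the reviewer's note, the timing print above would presumably become a logger call; data-juicer uses loguru elsewhere, so the change is likely along these lines (a sketch, not the committed fix):

```python
import time
from loguru import logger

start_time = time.time()
# ... MinHash stage ...
end_time = time.time()
logger.info(f'MinHash time = {end_time - start_time}')  # instead of print(...)
```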


@@ -556,5 +600,4 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
batch_format='pyarrow',
zero_copy_batch=True,
)
# logger.info(f'origin count = {dataset.count()}, keep count = {result.count()}')
return result