diff --git a/Makefile b/Makefile index d7693fc..f140b55 100644 --- a/Makefile +++ b/Makefile @@ -20,14 +20,17 @@ serve: up build-doc cd "$(BUILDDIR)" && python3 -m http.server test: up - ${DOCKER} compose exec local poetry run coverage run -m pytest -vvv -s --doctest-modules . --ignore deduplicate-text-datasets --ignore docs --ignore text_dedup/minhash_spark.py --ignore tests/test_benchmark.py + ${DOCKER} compose exec local poetry run coverage run -m pytest --doctest-modules . --ignore deduplicate-text-datasets --ignore docs --ignore text_dedup/minhash_spark.py --ignore tests/benchmark_core.py \ + --ignore tests/benchmark_news.py \ + --ignore tests/sweep_core.py \ + --ignore tests/sweep_news.py ${DOCKER} compose exec local poetry run coverage xml -o cobertura.xml ${DOCKER} compose exec local poetry run coverage report -m ${DOCKER} compose cp local:/app/cobertura.xml cobertura.xml benchmark: up - ${DOCKER} compose exec local poetry run python tests/test_benchmark_core.py - ${DOCKER} compose exec local poetry run python tests/test_benchmark_news.py + ${DOCKER} compose exec local poetry run python tests/benchmark_core.py + ${DOCKER} compose exec local poetry run python tests/benchmark_news.py spark_test: up ${DOCKER} compose exec local poetry run pytest -vvv -s --doctest-modules tests/test_minhash_spark.py diff --git a/README.md b/README.md index bb9bee0..6f101c4 100644 --- a/README.md +++ b/README.md @@ -270,15 +270,15 @@ INFO After : 47045
pinecone/core-2020-05-10-deduplication -See `tests/test_benchmark_core.py` for reproduction. +See `tests/benchmark_core.py` for reproduction. | Algorithm | Precision (Duplicates) | Recall (Duplicates) | Precision (Non Duplicates) | Recall (Non Duplicates) | Macro F1 score | Accuracy | Time | | :------------------------------ | ---------------------: | ------------------: | -------------------------: | ----------------------: | -------------: | --------: | :------- | -| MinHash (Spark) | 0.957 | 0.945 | 0.947 | 0.959 | **0.952** | 0.920 | 698.76s | -| MinHash | 0.959 | 0.945 | 0.947 | 0.962 | **0.953** | 0.924 | 18.80s | -| SimHash | 0.904 | 0.721 | 0.792 | 0.933 | 0.848 | 0.832 | 660.73s | -| UniSim/RETSimNear-Dup + ANN | 0.931 | 0.892 | 0.905 | 0.939 | 0.918 | 0.905 | 1222.87s | -| Exact Title | 0.830 | 0.552 | 0.710 | 0.907 | 0.77 | 0.746 | - | +| UniSim | 0.9307 | 0.8924 | 0.9055 | 0.9394 | 0.9181 | 0.9054 | 1305.79s | +| MinHash Spark | 0.957 | 0.9445 | 0.9471 | 0.959 | 0.952 | 0.9202 | 691.77s | +| MinHash | 0.9594 | 0.9445 | 0.9474 | 0.9616 | **0.9534** | 0.924 | 18.88s | +| SimHash | 0.9042 | 0.721 | 0.792 | 0.9329 | 0.8481 | 0.8321 | 644.36s | +| Exact Title | 0.8302 | 0.5521 | 0.7098 | 0.9065 | 0.77 | 0.7456 | - | | Exact Title Matching [^1] | 0.830 | 0.50 | 0.709 | 0.992 | 0.757 | 0.746 | - | | Simhash Matching [^1] | 0.697 | 0.247 | 0.598 | 0.985 | 0.631 | 0.616 | - | | Document Vector Similarity [^1] | 0.912 | 0.779 | 0.861 | 0.986 | 0.885 | 0.883 | - | @@ -294,29 +294,29 @@ See `tests/test_benchmark_core.py` for reproduction.
NEWS-COPY -See `tests/test_benchmark_news.py` for reproduction. +See `tests/benchmark_news.py` for reproduction. Adjusted Rand Index (ARI) on NEWS-COPY dataset: | Model/Algorithm | ARI | | :----------------------- | :-------- | -| n-gram [^3] | 0.440 | | SimHash | 0.612 | -| SimHash[^2] | 0.695 | | MinHash (Spark) | 0.740 | | MinHash | 0.742 | +| RETSim Near-Dup + ANN* | _0.051_ | +| n-gram [^3] | 0.440 | +| SimHash[^2] | 0.695 | | MinHash[^3] | 0.737 | | MinHash[^2] | 0.783 | | Multilingual USE[^2] | 0.730 | | Multilingual E5-Base[^2] | 0.742 | | S-BERT[^3] | 0.700 | -| RETSim Near-Dup + ANN* | _0.051_ | | RETSim Partial-Dup[^2] | 0.831 | | RETSim Near-Dup[^2] | 0.704 | | Re-ranking [^3] | **0.937** | | Bi-encoder [^3] | 0.915 | -\*: I can't seem to reproduce the results in the paper. +\*: I can't seem to reproduce the results from the paper. [^1]: [Deduplication of Scholarly Documents using Locality Sensitive Hashing and Word Embeddings](https://aclanthology.org/2020.lrec-1.113) [^2]: [RETSim: Resilient and Efficient Text Similarity](https://arxiv.org/abs/2311.17264) diff --git a/cobertura.xml b/cobertura.xml index bd24cf1..b692db1 100644 --- a/cobertura.xml +++ b/cobertura.xml [regenerated coverage report; XML hunks omitted]
diff --git a/tests/test_benchmark_core.py b/tests/benchmark_core.py similarity index 100% rename from tests/test_benchmark_core.py rename to tests/benchmark_core.py diff --git a/tests/test_benchmark_news.py b/tests/benchmark_news.py similarity index 100% rename from tests/test_benchmark_news.py rename to tests/benchmark_news.py diff --git a/tests/core_simhash_results.tsv b/tests/core_simhash_results.tsv deleted index 0a0deb0..0000000 --- a/tests/core_simhash_results.tsv +++ /dev/null @@ -1,31 +0,0 @@ -Algorithm Precision (Duplicates) Recall (Duplicates) Precision (Non Duplicates) Recall (Non Duplicates) Macro F1 score Accuracy Time bit_diff ngram -SimHash 0.9041246513623686 0.7209794696321642 0.791953693073096 0.9328512396694215 0.8480391722177323 0.83211 660.7345383167267 7 3 -SimHash 0.9031564438318401 0.7007165009089937 0.7804158493526874 0.9340219738942623 0.8417861465922638 0.82389 521.5012052059174 7 4 -SimHash 0.9048940464177598 0.6893510708718584 0.7738324731049064 0.9361846188568527 0.8393632597613332 0.81982 478.519406080246 7 5 -SimHash 0.9006558587206496 0.6786470701495411 0.7681192979206224 0.9342997164691965 0.834387578320636 0.81371 253.9916753768921 6 3 -SimHash 0.9020347508001829 0.6746462617022186 0.7658442037903027 0.9355808621791258 0.8339394772952429 0.81266 509.02215027809143 7 6 -SimHash 0.903674655047204 0.6635029646376317 0.7594052611513534 0.9375682494257634 0.8315399580992787 0.80832 521.268883228302 7 7 -SimHash 0.9015530479292694 0.6618061485909479 0.758580224335528 0.9363243039879608 0.8300666361323987 0.80703 209.40148067474365 6 4 -SimHash 0.9030407679475548 0.6583315553659057 0.7567407452416035 0.9376435159043854 0.8298907565945792 0.80603 547.1693751811981 7 8 -SimHash 0.9013889298457728 0.6527365521364966 0.7539378716579158 0.9371085742227906 0.8276634007518443 0.80334 213.39266991615295 6 5 -SimHash 0.9009602938077607 0.6384717126034645 0.7462302151809647 0.9380694224832468 0.8235952544943627 0.79708 207.158132314682 6 6 -SimHash 0.9013104137033562 0.6348482264665757 0.7441641650611587 0.9385736889692586 0.8227372893822574 0.79548 60.34354829788208 5 3 -SimHash 0.9021779028508503 0.6279862027849934 0.7404171866642896 0.9396990609797489 0.8212975447575699 0.79294 204.3841519355774 6 7 -SimHash 0.9017967019443761 0.624190800681431 0.738532653786891 0.9398190045248869 0.8201646778656335 0.79114 229.03749799728394 6 8 -SimHash 0.9007131439813107 0.6243102162565249 0.7386464694373629 0.9391312541223028 0.8196798067093368 0.79084 51.68418741226196 5 4 -SimHash 0.9023949108145192 0.6156235374207548 0.7340281457928517 0.9409367098162056 0.8182115283036855 0.78765 50.14804434776306 5 5 -SimHash 0.9021569935512564 0.6035021357077587 0.7277039155879219 0.9418242260544359 0.8149304545695892 0.78232 49.01416015625 5 6 -SimHash 0.904700909440246 0.5991432328116385 0.7251497702553364
0.9436854255762025 0.8149253398477911 0.78088 24.663214445114136 4 3 -SimHash 0.9054171355087607 0.5968381121919173 0.7238696567240003 0.9442940185181679 0.8146433961163805 0.78005 50.44178080558777 5 7 -SimHash 0.9028739724253997 0.5919410965880156 0.7217012532198779 0.9432402784082312 0.8122876128226388 0.7773 50.281965255737305 5 8 -SimHash 0.9041430525698652 0.5885757903441257 0.7199919189587001 0.9443003955561444 0.8120674857642827 0.77626 23.02803683280945 4 4 -SimHash 0.9058266473760358 0.5831251852165446 0.7169914353049376 0.9457143940255506 0.8114090413404866 0.77417 24.0515878200531 4 5 -SimHash 0.9048858630356428 0.5738169812917971 0.7124621623165229 0.9459695154318647 0.8086740126760829 0.7699 24.691195964813232 4 6 -SimHash 0.9081697916316019 0.5696699676935746 0.7100707040530352 0.9481772762675481 0.8091202478423185 0.76872 25.0836443901062 4 7 -SimHash 0.909447620336231 0.5669879162361079 0.7084912758919318 0.9490880736387669 0.8089694481140814 0.76771 17.250157594680786 3 3 -SimHash 0.9066820276497696 0.5652780125063377 0.7080637839064805 0.9477062129728088 0.807372905778125 0.76651 25.85455298423767 4 8 -SimHash 0.9088453516599994 0.5606859601754978 0.7056379234802764 0.9493078795254031 0.8072416375701379 0.76497 18.530534744262695 3 4 -SimHash 0.9104704097116844 0.5560118784355847 0.7031012337333108 0.9505702698071175 0.8067858217224976 0.76306 19.220607042312622 3 5 -SimHash 0.9104560816696282 0.5501031535514294 0.7002622757682436 0.9510457541239666 0.8053591787189359 0.76044 20.553274869918823 3 6 -SimHash 0.9129719128668488 0.5479213081676405 0.6989474568205358 0.9525962382205944 0.8059596848436923 0.75993 21.3585786819458 3 7 -SimHash 0.9138330757341576 0.5463614407224614 0.6980372420734776 0.9531736184022144 0.8059351589038176 0.75938 22.101569414138794 3 8 diff --git a/tests/news_simhash_results.tsv b/tests/news_simhash_results.tsv deleted file mode 100644 index 5b163ff..0000000 --- a/tests/news_simhash_results.tsv +++ /dev/null @@ -1,49 +0,0 @@ -ARI time bit_diff ngram -0.6822147107954453 12.525881052017212 12 5 -0.675185246041429 5.931649446487427 11 4 -0.6732026066460217 5.965769052505493 10 4 -0.6708917788705451 12.546626567840576 14 8 -0.6637238676788538 12.817933082580566 14 9 -0.6522571909180018 12.567753314971924 13 6 -0.630860315439797 12.620974779129028 13 8 -0.6293158559575469 12.495429515838623 12 6 -0.6220587533491597 5.83250093460083 11 5 -0.6212963580344466 12.526573181152344 12 7 -0.6105532231167057 38.34381318092346 15 8 -0.5964268349580873 12.600701093673706 13 9 -0.5706028925112692 12.653830766677856 14 10 -0.5617146279647695 12.598326444625854 13 7 -0.5474046009449512 12.553319454193115 13 5 -0.5442163045802813 5.9313647747039795 11 6 -0.543946526243722 12.49471664428711 12 8 -0.5399553395495874 5.94044041633606 11 7 -0.5376902189864985 5.817416667938232 10 5 -0.5358110222191643 38.36264395713806 15 9 -0.5300801453038461 38.55435395240784 15 10 -0.5021251623600166 12.63725996017456 13 10 -0.4976895857667319 12.62915849685669 12 9 -0.4564912425443016 12.69089412689209 14 7 -0.4505990754372845 5.907226800918579 10 6 -0.4374959951652071 12.729954242706299 12 10 -0.4355602238623215 6.023991823196411 11 8 -0.39976431697615106 5.911164283752441 10 7 -0.38398141106324357 6.0631513595581055 11 9 -0.34822125655862574 6.080265045166016 11 10 -0.346061974061667 5.9928224086761475 10 8 -0.3328310471696863 12.489759922027588 14 6 -0.30390143723768814 38.99505019187927 15 7 -0.3022484069911815 6.055165529251099 10 9 -0.22891970231583927 6.0675036907196045 10 
10 -0.15511192127369075 13.053264617919922 12 4 -0.10248533847317527 6.086925506591797 10 3 -0.08391493975721878 12.538282632827759 14 5 -0.053372269805519144 38.39644193649292 15 6 -0.04762594558725042 12.583284139633179 13 4 -0.0288788685718963 5.963608503341675 11 3 -0.02689731257283987 39.12379431724548 15 5 -0.014885685897021182 12.797574996948242 14 4 -0.01057028536897794 12.928021430969238 12 3 -0.005034380255319022 13.08185863494873 13 3 -0.00488625777725783 38.875048875808716 15 4 -0.0023038646093864255 13.152316570281982 14 3 -0.0008913682980693251 41.63720178604126 15 3 diff --git a/tests/test_ccnet.py b/tests/test_ccnet.py new file mode 100644 index 0000000..767f655 --- /dev/null +++ b/tests/test_ccnet.py @@ -0,0 +1,37 @@ +import subprocess # nosec + + +def test_exact_hash(): + result = subprocess.run( + [ + "python", + "-m", + "text_dedup.ccnet", + "--path", + "allenai/c4", + "--name", + "xh", + "--split", + "train", + "--cache_dir", + ".cache", + "--output", + ".temp-output", + "--column", + "text", + "--batch_size", + "10000", + ], + capture_output=True, + text=True, + ) # nosec + + # check the output + print(f"Output:\n{result.stdout}") + assert ( + "69048" in result.stdout and "68221" in result.stdout + ), f"Expected before and after are not present in the output: {result.stdout}" + + # remove the output and input + # subprocess.run(["rm", "-rf", ".cache"]) # nosec + subprocess.run(["rm", "-rf", ".temp-output"]) # nosec diff --git a/tests/test_unisim.py b/tests/test_unisim.py new file mode 100644 index 0000000..1283be6 --- /dev/null +++ b/tests/test_unisim.py @@ -0,0 +1,36 @@ +import subprocess # nosec + + +def test_minhash(): + result = subprocess.run( + [ + "python", + "-m", + "text_dedup.ann_unisim", + "--path", + "truthful_qa", + "--name", + "generation", + "--split", + "validation", + "--cache_dir", + ".cache", + "--output", + ".temp-output", + "--column", + "question", + "--batch_size", + "24", + ], + capture_output=True, + text=True, + ) # nosec + + # check the output + assert ( + "817" in result.stdout and "788" in result.stdout + ), f"Expected before and after are not present in the output: {result.stdout}" + + # remove the output and input + # subprocess.run(["rm", "-rf", ".cache"]) # nosec + subprocess.run(["rm", "-rf", ".temp-output"]) # nosec diff --git a/text_dedup/ann_unisim.py b/text_dedup/ann_unisim.py index 1fb0234..7b78e32 100644 --- a/text_dedup/ann_unisim.py +++ b/text_dedup/ann_unisim.py @@ -1,6 +1,5 @@ import inspect import os -import pickle # nosec import random from pathlib import Path @@ -12,14 +11,18 @@ from unisim.embedder import Embedder from text_dedup import logger -from text_dedup.utils.args import IOArgs -from text_dedup.utils.args import MetaArgs -from text_dedup.utils.args import UniSimArgs -from text_dedup.utils.inspect import random_samples -from text_dedup.utils.load import load_hf_dataset -from text_dedup.utils.memory import DisableReferenceCount -from text_dedup.utils.timer import Timer -from text_dedup.utils.union_find import UnionFind +from text_dedup.utils import CLUSTER_COLUMN +from text_dedup.utils import INDEX_COLUMN +from text_dedup.utils import DisableReferenceCount +from text_dedup.utils import IOArgs +from text_dedup.utils import MetaArgs +from text_dedup.utils import Timer +from text_dedup.utils import UnionFind +from text_dedup.utils import UniSimArgs +from text_dedup.utils import load_hf_dataset +from text_dedup.utils import random_samples + +EMBEDDING_COLUMN = "__embeddings__" class WrapInferenceSession: @@ -67,26 
+70,16 @@ def main(io_args: IOArgs, meta_args: MetaArgs, unisim_args: UniSimArgs): with timer("Total"): with timer("Loading"): - ds = load_hf_dataset(io_args) - if meta_args.idx_column is not None: - original_idx = ds[meta_args.idx_column] - else: - original_idx = list(range(len(ds))) - - ds = ds.map(lambda x, i: {"__idx__": i}, with_indices=True, num_proc=io_args.num_proc) - meta_args.idx_column = "__idx__" - id2id = {new: old for new, old in zip(ds["__idx__"], original_idx)} + ds, id2id = load_hf_dataset(io_args=io_args, meta_args=meta_args) with timer("Embedding"): ds = ds.map( lambda batch: { - "__embeddings__": text_sim.embedder.embed(batch[meta_args.column]), + EMBEDDING_COLUMN: text_sim.embedder.embed(batch[meta_args.column]), }, num_proc=io_args.num_proc, batched=True, batch_size=meta_args.batch_size, - new_fingerprint="Thisisatestb", - cache_file_name="Thisisatestb.b", load_from_cache_file=True, ) @@ -103,8 +96,8 @@ def main(io_args: IOArgs, meta_args: MetaArgs, unisim_args: UniSimArgs): shard = ds.shard( num_shards=NUM_SHARDS, index=batch_idx, contiguous=True, writer_batch_size=meta_args.batch_size ) - batch_indices = shard[meta_args.idx_column] - batch_embedds = shard["__embeddings__"] + batch_indices = shard[INDEX_COLUMN] + batch_embedds = shard[EMBEDDING_COLUMN] text_sim.indexer.add(batch_embedds, batch_indices) if unisim_args.store_data: text_sim.indexed_data.extend(shard[meta_args.column]) @@ -121,8 +114,8 @@ def main(io_args: IOArgs, meta_args: MetaArgs, unisim_args: UniSimArgs): num_shards=NUM_SHARDS, index=batch_idx, contiguous=True, writer_batch_size=meta_args.batch_size ) - remain_embedds = shard["__embeddings__"] - remain_indices = shard[meta_args.idx_column] + remain_embedds = shard[EMBEDDING_COLUMN] + remain_indices = shard[INDEX_COLUMN] shard_results = [[] for _ in remain_indices] k = 20 while remain_embedds and remain_indices: @@ -151,7 +144,7 @@ def main(io_args: IOArgs, meta_args: MetaArgs, unisim_args: UniSimArgs): k *= 2 - results.extend(zip(shard[meta_args.idx_column], shard_results)) + results.extend(zip(shard[INDEX_COLUMN], shard_results)) with timer("Clustering"): for idx, matches in tqdm(results): @@ -160,15 +153,15 @@ def main(io_args: IOArgs, meta_args: MetaArgs, unisim_args: UniSimArgs): with timer("Filtering"), DisableReferenceCount(): ds = ds.map( - function=lambda _, idx: {"__cluster__": uf.find(idx)}, - with_indices=True, + function=lambda record: {CLUSTER_COLUMN: uf.find(record[INDEX_COLUMN])}, + with_indices=False, num_proc=io_args.num_proc, # type: ignore new_fingerprint=str(random.getrandbits(128)), # type: ignore desc="Finding clusters...", # type: ignore ) final_data = ds.filter( - function=lambda record, idx: record["__cluster__"] == idx, - with_indices=True, + function=lambda record: record[CLUSTER_COLUMN] == record[INDEX_COLUMN], + with_indices=False, num_proc=io_args.num_proc, desc="Filtering clusters...", ) @@ -180,12 +173,7 @@ def main(io_args: IOArgs, meta_args: MetaArgs, unisim_args: UniSimArgs): final_data = final_data.remove_columns(["__cluster__"]) final_data.save_to_disk(io_args.output) if io_args.debug: - with open(os.path.join(io_args.output, "uf.pkl"), "wb") as f: - # use the original index instead of the new one - new_uf = UnionFind() - for key in uf.parent: - new_uf.union(id2id[key], id2id[uf.find(key)]) - pickle.dump(new_uf, f, protocol=pickle.HIGHEST_PROTOCOL) + uf.dump(os.path.join(io_args.output, "uf.pkl"), id2id=id2id) with timer("Cleaning"): if io_args.clean_cache: diff --git a/text_dedup/bloom_filter.py 
b/text_dedup/bloom_filter.py index 9c4c0c8..41d9f12 100644 --- a/text_dedup/bloom_filter.py +++ b/text_dedup/bloom_filter.py @@ -6,7 +6,6 @@ import click import numpy as np -from datasets import Dataset from pybloom_live import ScalableBloomFilter from tqdm import tqdm @@ -50,7 +49,7 @@ def main( with timer("Total"): with timer("Loading"): - ds: Dataset = load_hf_dataset(io_args) + ds, _ = load_hf_dataset(io_args=io_args, meta_args=meta_args) LEN_DATASET = len(ds) NUM_SHARDS = int(np.ceil(LEN_DATASET / meta_args.batch_size)) diff --git a/text_dedup/ccnet.py b/text_dedup/ccnet.py index 4be0301..19d8a79 100644 --- a/text_dedup/ccnet.py +++ b/text_dedup/ccnet.py @@ -12,24 +12,26 @@ import click import numpy as np -from datasets import Dataset from tqdm import tqdm from text_dedup import logger +from text_dedup.utils import INDEX_COLUMN +from text_dedup.utils import DisableReferenceCount from text_dedup.utils import ExactHashArgs from text_dedup.utils import IOArgs from text_dedup.utils import MetaArgs -from text_dedup.utils.hashfunc import md5_digest -from text_dedup.utils.hashfunc import sha256_digest -from text_dedup.utils.hashfunc import xxh3_64_digest -from text_dedup.utils.hashfunc import xxh3_128_digest -from text_dedup.utils.load import load_hf_dataset -from text_dedup.utils.memory import DisableReferenceCount -from text_dedup.utils.preprocess import normalize as normalize_for_dedup -from text_dedup.utils.timer import Timer +from text_dedup.utils import Timer +from text_dedup.utils import load_hf_dataset +from text_dedup.utils import md5_digest +from text_dedup.utils import normalize as normalize_for_dedup +from text_dedup.utils import sha256_digest +from text_dedup.utils import xxh3_64_digest +from text_dedup.utils import xxh3_128_digest HASH_SIZE = np.uint64(0).nbytes # 8 bytes mp.set_start_method("fork", force=True) +HASH_COLUMN = "__hash__" +ID_COLUMN = "__id__" def compute_hashes( @@ -61,9 +63,9 @@ def compute_hashes( n = len(lines) hashes = [hash_func(bytes(normalize_for_dedup(line), encoding="utf-8")) for line in lines] return { - "__hash__": hashes, - "__id__": [idx for _ in range(n)], - "__idx__": list(range(n)), + HASH_COLUMN: hashes, + ID_COLUMN: [idx for _ in range(n)], + INDEX_COLUMN: list(range(n)), } @@ -131,7 +133,7 @@ def xxh3_digest_sized(data: bytes) -> bytes: with timer("Total"): with timer("Loading"): - ds: Dataset = load_hf_dataset(io_args) + ds, _ = load_hf_dataset(io_args=io_args, meta_args=meta_args) LEN_DATASET = len(ds) hashes = set() @@ -142,18 +144,17 @@ def xxh3_digest_sized(data: bytes) -> bytes: compute_hashes, batched=True, batch_size=1, - with_indices=True if meta_args.idx_column is None else False, + with_indices=False, num_proc=io_args.num_proc, - fn_kwargs={"column": meta_args.column, "hash_func": hash_func} - | ({"idx_column": meta_args.idx_column, "idx": None} if meta_args.idx_column is not None else {}), - remove_columns=ds.column_names, + fn_kwargs={"column": meta_args.column, "hash_func": hash_func, "idx_column": INDEX_COLUMN, "idx": None}, + remove_columns=[c for c in ds.column_names if c != INDEX_COLUMN], desc="Computing hashes...", ) NUM_SHARDS = int(np.ceil(len(hashed) / meta_args.batch_size)) for batch_idx in tqdm(range(0, NUM_SHARDS), desc="Processing..."): ds_shard = hashed.shard(NUM_SHARDS, batch_idx, contiguous=True) for h, id_, idx in tqdm( - zip(ds_shard["__hash__"], ds_shard["__id__"], ds_shard["__idx__"]), + zip(ds_shard[HASH_COLUMN], ds_shard[ID_COLUMN], ds_shard[INDEX_COLUMN]), leave=False, ): if h in hashes: diff --git 
a/text_dedup/exact_hash.py b/text_dedup/exact_hash.py index fc33170..6f8254a 100644 --- a/text_dedup/exact_hash.py +++ b/text_dedup/exact_hash.py @@ -6,7 +6,6 @@ import click import numpy as np -from datasets import Dataset from tqdm import tqdm from text_dedup import logger @@ -46,7 +45,7 @@ def main( with timer("Total"): with timer("Loading"): - ds: Dataset = load_hf_dataset(io_args) + ds, _ = load_hf_dataset(io_args=io_args, meta_args=meta_args) LEN_DATASET: int = len(ds) NUM_SHARDS = int(np.ceil(LEN_DATASET / meta_args.batch_size)) diff --git a/text_dedup/minhash.py b/text_dedup/minhash.py index e4a46fb..33b0f17 100644 --- a/text_dedup/minhash.py +++ b/text_dedup/minhash.py @@ -5,7 +5,6 @@ import multiprocessing as mp import os -import pickle # nosec import random import re from collections import defaultdict @@ -15,22 +14,23 @@ import click import datasets import numpy as np -from datasets import Dataset from tqdm import tqdm from text_dedup import logger +from text_dedup.utils import CLUSTER_COLUMN +from text_dedup.utils import INDEX_COLUMN +from text_dedup.utils import DisableReferenceCount +from text_dedup.utils import IOArgs +from text_dedup.utils import MetaArgs +from text_dedup.utils import MinHashArgs +from text_dedup.utils import Timer from text_dedup.utils import UnionFind +from text_dedup.utils import load_hf_dataset from text_dedup.utils import ngrams -from text_dedup.utils.analysis import optimal_param -from text_dedup.utils.args import IOArgs -from text_dedup.utils.args import MetaArgs -from text_dedup.utils.args import MinHashArgs -from text_dedup.utils.hashfunc import sha1_hash -from text_dedup.utils.hashfunc import xxh3_16hash -from text_dedup.utils.hashfunc import xxh3_32hash -from text_dedup.utils.load import load_hf_dataset -from text_dedup.utils.memory import DisableReferenceCount -from text_dedup.utils.timer import Timer +from text_dedup.utils import optimal_param +from text_dedup.utils import sha1_hash +from text_dedup.utils import xxh3_16hash +from text_dedup.utils import xxh3_32hash SEED = 42 RNG = np.random.RandomState(SEED) @@ -40,6 +40,7 @@ # is not copied to child processes as long as it is not modified. mp.set_start_method("fork", force=True) uf = UnionFind() +SIGNATURE_COLUMN = "__signatures__" def embed_func( @@ -106,9 +107,9 @@ def embed_func( ... max_hash=max_hash, ... modulo_prime=modulo_prime, ... ) - >>> len(res["__signatures__"]) + >>> len(res[SIGNATURE_COLUMN]) 10 - >>> res["__id__"] + >>> res[INDEX_COLUMN] 0 """ # a, b are each np.ndarray arrays containing {num_perm} pairs of random numbers used for building new hashes @@ -133,7 +134,7 @@ def embed_func( # keeping for backward compatibility, even though theoretically and empirically # it doesnt matter if it is there or not. 
github.com/ekzhu/datasketch/issues/114 Hs: list[bytes] = [bytes(hashvalues[start:end].byteswap().data) for start, end in hashranges] - return {"__signatures__": Hs, "__id__": idx} + return {SIGNATURE_COLUMN: Hs, INDEX_COLUMN: idx} @click.command @@ -213,7 +214,7 @@ def hash_func(byte_data): with timer("Total"): with timer("Loading"): - ds: Dataset = load_hf_dataset(io_args) + ds, id2id = load_hf_dataset(io_args=io_args, meta_args=meta_args) ds = ds.filter( lambda x: len(NON_ALPHA.split(x[meta_args.column].lower())) >= minhash_args.min_length, num_proc=io_args.num_proc, @@ -235,12 +236,10 @@ def hash_func(byte_data): "max_hash": MAX_HASH, "modulo_prime": MODULO_PRIME, }, - input_columns=( - [meta_args.column] if meta_args.idx_column is None else [meta_args.column, meta_args.idx_column] - ), - remove_columns=ds.column_names, + input_columns=[meta_args.column, INDEX_COLUMN], + remove_columns=[col for col in ds.column_names if col != INDEX_COLUMN], num_proc=io_args.num_proc, - with_indices=True if meta_args.idx_column is None else False, + with_indices=False, desc="Fingerprinting...", ) LEN_EMBEDDED = len(embedded) @@ -259,7 +258,7 @@ def hash_func(byte_data): contiguous=True, writer_batch_size=meta_args.batch_size, ) - for key, Hs in zip(embedded_shard["__id__"], embedded_shard["__signatures__"]): + for key, Hs in zip(embedded_shard[INDEX_COLUMN], embedded_shard[SIGNATURE_COLUMN]): for i, H in enumerate(Hs): HASH_TABLES[i][H].add(key) @@ -277,8 +276,8 @@ def hash_func(byte_data): with timer("Filtering"), DisableReferenceCount(): ds = ds.map( - function=lambda _, idx: {"__cluster__": uf.find(idx)}, - with_indices=True, + function=lambda record: {CLUSTER_COLUMN: uf.find(record[INDEX_COLUMN])}, + with_indices=False, num_proc=io_args.num_proc, new_fingerprint=str(random.getrandbits(128)), desc="Finding clusters...", @@ -287,18 +286,17 @@ def hash_func(byte_data): # Since there is no easy groupby in datasets # I will use this simple filter for now final_data = ds.filter( - function=lambda record, idx: record["__cluster__"] == idx, - with_indices=True, + function=lambda record: record[CLUSTER_COLUMN] == record[INDEX_COLUMN], + with_indices=False, num_proc=io_args.num_proc, desc="Filtering clusters...", ) with timer("Saving"): - final_data = final_data.remove_columns(["__cluster__"]) + final_data = final_data.remove_columns([CLUSTER_COLUMN, INDEX_COLUMN]) final_data.save_to_disk(io_args.output) if io_args.debug: - with open(os.path.join(io_args.output, "uf.pkl"), "wb") as f: - pickle.dump(uf, f, protocol=pickle.HIGHEST_PROTOCOL) + uf.dump(os.path.join(io_args.output, "uf.pkl"), id2id=id2id) with timer("Cleaning"): if io_args.clean_cache: diff --git a/text_dedup/simhash.py b/text_dedup/simhash.py index 7ea5224..a6806c3 100644 --- a/text_dedup/simhash.py +++ b/text_dedup/simhash.py @@ -6,7 +6,6 @@ import math import multiprocessing as mp import os -import pickle # nosec import random from collections import defaultdict from itertools import permutations @@ -18,24 +17,27 @@ import numpy as np from bitarray import bitarray from bitarray import frozenbitarray -from datasets import Dataset from tqdm import tqdm from text_dedup import logger +from text_dedup.utils import CLUSTER_COLUMN +from text_dedup.utils import INDEX_COLUMN +from text_dedup.utils import DisableReferenceCount from text_dedup.utils import IOArgs from text_dedup.utils import MetaArgs from text_dedup.utils import SimHashArgs +from text_dedup.utils import Timer from text_dedup.utils import UnionFind +from text_dedup.utils import 
load_hf_dataset from text_dedup.utils import ngrams -from text_dedup.utils.hashfunc import xxh3_64_digest -from text_dedup.utils.hashfunc import xxh3_128_digest -from text_dedup.utils.load import load_hf_dataset -from text_dedup.utils.memory import DisableReferenceCount -from text_dedup.utils.timer import Timer +from text_dedup.utils import xxh3_64_digest +from text_dedup.utils import xxh3_128_digest mp.set_start_method("fork", force=True) datasets.logging.set_verbosity_error() uf = UnionFind() +KEY_COLUMN = "__keys__" +SIGNATURE_COLUMN = "__signature__" def _hamming_distance(a: bitarray, b: bitarray) -> int: @@ -204,7 +206,6 @@ def _create_permutations(f: int, k: int, b: int) -> list[Permutation]: y = (f - x * max_block_size) // min_block_size break - logger.info(f"{x=} w/ {max_block_size}, {y=} w/ {min_block_size}") assert ( x * max_block_size + y * min_block_size == f ), f"{x=} w/ {max_block_size}, {y=} w/ {min_block_size} are invalid" @@ -330,9 +331,9 @@ def embed_func( Examples -------- >>> res = embed_func("hello world", 0, ngram=3, permutations=None, hash_func=xxh3_64_digest) - >>> res["__id__"] + >>> res[INDEX_COLUMN] 0 - >>> len(res["__signature__"]) + >>> len(res[SIGNATURE_COLUMN]) 8 """ tokens = {bytes("".join(ng).lower(), "utf-8") for ng in ngrams(list(content), n=ngram)} @@ -346,7 +347,7 @@ def embed_func( (permutation.permute(sig) & permutation.search_mask).tobytes(), ) ) - return {"__id__": idx, "__keys__": keys, "__signature__": sig.tobytes()} + return {INDEX_COLUMN: idx, KEY_COLUMN: keys, SIGNATURE_COLUMN: sig.tobytes()} @click.command @@ -369,7 +370,7 @@ def main( with timer("Total"): with timer("Loading"): - ds: Dataset = load_hf_dataset(io_args) + ds, id2id = load_hf_dataset(io_args=io_args, meta_args=meta_args) LEN_DATASET = len(ds) # type: ignore @@ -381,12 +382,10 @@ def main( "permutations": PERMUTATIONS, "hash_func": hash_func, }, - input_columns=( - [meta_args.column] if meta_args.idx_column is None else [meta_args.column, meta_args.idx_column] - ), + input_columns=[meta_args.column, INDEX_COLUMN], remove_columns=[meta_args.column], num_proc=io_args.num_proc, # type: ignore - with_indices=True if meta_args.idx_column is None else False, + with_indices=False, desc="SimHashing...", # type: ignore ) @@ -404,7 +403,7 @@ def main( num_shards=NUM_SHARDS, index=batch_idx, contiguous=True, writer_batch_size=meta_args.batch_size ) for idx, keys, sig in tqdm( - zip(embedded_shard["__id__"], embedded_shard["__keys__"], embedded_shard["__signature__"]), + zip(embedded_shard[INDEX_COLUMN], embedded_shard[KEY_COLUMN], embedded_shard[SIGNATURE_COLUMN]), desc="Indexing...", leave=False, total=len(embedded_shard), @@ -429,8 +428,8 @@ def main( with timer("Filtering"), DisableReferenceCount(): ds = ds.map( - function=lambda _, idx: {"__cluster__": uf.find(idx)}, - with_indices=True, + function=lambda record: {CLUSTER_COLUMN: uf.find(record[INDEX_COLUMN])}, + with_indices=False, num_proc=io_args.num_proc, # type: ignore new_fingerprint=str(random.getrandbits(128)), # type: ignore desc="Finding clusters...", # type: ignore @@ -439,18 +438,17 @@ def main( # Since there is no easy groupby in datasets # I will use this simple filter for now final_data = ds.filter( - function=lambda record, idx: record["__cluster__"] == idx, - with_indices=True, + function=lambda record: record[CLUSTER_COLUMN] == record[INDEX_COLUMN], + with_indices=False, num_proc=io_args.num_proc, desc="Filtering clusters...", ) with timer("Saving"): - final_data = final_data.remove_columns(["__cluster__"]) + 
final_data = final_data.remove_columns([CLUSTER_COLUMN, INDEX_COLUMN]) final_data.save_to_disk(io_args.output) if io_args.debug: - with open(os.path.join(io_args.output, "uf.pkl"), "wb") as f: - pickle.dump(uf, f, protocol=pickle.HIGHEST_PROTOCOL) + uf.dump(path=os.path.join(io_args.output, "uf.pkl"), id2id=id2id) with timer("Cleaning"): if io_args.clean_cache: diff --git a/text_dedup/suffix_array.py b/text_dedup/suffix_array.py index b695add..3fe7b4d 100644 --- a/text_dedup/suffix_array.py +++ b/text_dedup/suffix_array.py @@ -16,7 +16,6 @@ import click import datasets -from datasets import Dataset from text_dedup import logger from text_dedup.utils import IOArgs @@ -321,7 +320,7 @@ def main( with timer("Total"): with timer("Loading"): - ds: Dataset = load_hf_dataset(io_args) + ds, _ = load_hf_dataset(io_args=io_args, meta_args=meta_args) with timer("Preprocessing"): offsets: list[slice] = [] diff --git a/text_dedup/utils/__init__.py b/text_dedup/utils/__init__.py index 84f7460..f82b532 100644 --- a/text_dedup/utils/__init__.py +++ b/text_dedup/utils/__init__.py @@ -2,6 +2,7 @@ # @Date : 2022-12-26 15:42:09 # @Author : Chenghao Mou (mouchenghao@gmail.com) +from text_dedup.utils.analysis import optimal_param from text_dedup.utils.args import BloomFilterArgs from text_dedup.utils.args import ExactHashArgs from text_dedup.utils.args import IOArgs @@ -10,8 +11,27 @@ from text_dedup.utils.args import SAArgs from text_dedup.utils.args import SimHashArgs from text_dedup.utils.args import UniSimArgs +from text_dedup.utils.const import CLUSTER_COLUMN +from text_dedup.utils.const import INDEX_COLUMN +from text_dedup.utils.hashfunc import md5 +from text_dedup.utils.hashfunc import md5_digest +from text_dedup.utils.hashfunc import md5_hexdigest from text_dedup.utils.hashfunc import sha1_hash +from text_dedup.utils.hashfunc import sha256 +from text_dedup.utils.hashfunc import sha256_digest +from text_dedup.utils.hashfunc import sha256_hexdigest +from text_dedup.utils.hashfunc import xxh3_16hash +from text_dedup.utils.hashfunc import xxh3_32hash +from text_dedup.utils.hashfunc import xxh3_64 +from text_dedup.utils.hashfunc import xxh3_64_digest +from text_dedup.utils.hashfunc import xxh3_128 +from text_dedup.utils.hashfunc import xxh3_128_digest from text_dedup.utils.hashfunc import xxh3_hash +from text_dedup.utils.inspect import random_samples +from text_dedup.utils.load import load_hf_dataset +from text_dedup.utils.memory import DisableReferenceCount +from text_dedup.utils.preprocess import news_copy_preprocessing +from text_dedup.utils.preprocess import normalize from text_dedup.utils.timer import Timer from text_dedup.utils.tokenization import ngrams from text_dedup.utils.union_find import UnionFind @@ -30,4 +50,26 @@ "UnionFind", "sha1_hash", "xxh3_hash", + "load_hf_dataset", + "DisableReferenceCount", + "random_samples", + "normalize", + "news_copy_preprocessing", + "INDEX_COLUMN", + "CLUSTER_COLUMN", + "md5", + "sha256", + "sha1_hash", + "xxh3_64", + "xxh3_64_digest", + "xxh3_128", + "xxh3_128_digest", + "xxh3_hash", + "xxh3_16hash", + "xxh3_32hash", + "optimal_param", + "md5_digest", + "md5_hexdigest", + "sha256_digest", + "sha256_hexdigest", ] diff --git a/text_dedup/utils/hashfunc.py b/text_dedup/utils/hashfunc.py index 244a45c..2cab4d0 100644 --- a/text_dedup/utils/hashfunc.py +++ b/text_dedup/utils/hashfunc.py @@ -255,4 +255,8 @@ def xxh3_hash(data: bytes, d: int = 32) -> int: "xxh3_hash", "xxh3_16hash", "xxh3_32hash", + "md5_digest", + "md5_hexdigest", + "sha256_digest", + 
"sha256_hexdigest", ] diff --git a/text_dedup/utils/load.py b/text_dedup/utils/load.py index 779f191..0715e42 100644 --- a/text_dedup/utils/load.py +++ b/text_dedup/utils/load.py @@ -2,10 +2,12 @@ from datasets import load_dataset from datasets import load_from_disk -from text_dedup.utils.args import IOArgs +from text_dedup.utils import INDEX_COLUMN +from text_dedup.utils import IOArgs +from text_dedup.utils import MetaArgs -def load_hf_dataset(io_args: IOArgs) -> Dataset: +def load_hf_dataset(io_args: IOArgs, meta_args: MetaArgs) -> Dataset: """ A simple wraper to load a huggingface dataset. @@ -13,6 +15,8 @@ def load_hf_dataset(io_args: IOArgs) -> Dataset: ---------- io_args : IOArgs The arguments for the dataset to load. + meta_args : MetaArgs + The arguments for the meta parameters of the dataset to load. Returns ------- @@ -34,5 +38,9 @@ def load_hf_dataset(io_args: IOArgs) -> Dataset: num_proc=io_args.num_proc, token=io_args.use_auth_token, ) - - return ds + ds = ds.map(lambda x, i: {INDEX_COLUMN: i}, with_indices=True, num_proc=io_args.num_proc) + id2id = None + if meta_args.idx_column is not None: + original_index = ds[meta_args.idx_column] + id2id = {idx: oid for idx, oid in zip(ds[INDEX_COLUMN], original_index)} + return ds, id2id diff --git a/text_dedup/utils/union_find.py b/text_dedup/utils/union_find.py index 99fa0cf..b306b57 100644 --- a/text_dedup/utils/union_find.py +++ b/text_dedup/utils/union_find.py @@ -1,7 +1,9 @@ #!/usr/bin/env python # @Date : 2022-12-26 15:37:44 # @Author : Chenghao Mou (mouchenghao@gmail.com) +import pickle # nosec from collections import Counter +from pathlib import Path class UnionFind: @@ -83,3 +85,14 @@ def union(self, x, y): def reset(self): self.parent = {} self.rank = Counter() + + def dump(self, path: str | Path, id2id=None): + if id2id is not None: + new_uf = UnionFind() + for i in self.parent: + new_uf.union(id2id[i], id2id[self.find(i)]) + else: + new_uf = self + + with open(path, "wb") as f: + pickle.dump(new_uf, f, protocol=pickle.HIGHEST_PROTOCOL)