Skip to content

Commit

Permalink
MinHash improvement using minhash_permuted (#313)
Browse files Browse the repository at this point in the history
  • Loading branch information
praateekmahajan authored Nov 20, 2024
1 parent 8408a7b commit 07e2a40
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 12 deletions.
27 changes: 20 additions & 7 deletions nemo_curator/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,31 @@
import sys

import dask
from packaging.version import parse as parseVersion
from packaging.version import parse as parse_version

try:
_dask_version = parseVersion(dask.__version__)
_dask_version = parse_version(dask.__version__)
except TypeError:
# When mocking with autodoc the dask version is not there
_dask_version = parseVersion("2024.06.0")
_dask_version = parse_version("2024.06.0")

try:
import cudf

CURRENT_CUDF_VERSION = parse_version(cudf.__version__)
except (ImportError, TypeError):
CURRENT_CUDF_VERSION = parse_version("24.10.0")

# TODO remove this once 24.12.0 becomes the base version of cudf in nemo-curator
MINHASH_PERMUTED_AVAILABLE = CURRENT_CUDF_VERSION >= parse_version("24.12.0") or (
CURRENT_CUDF_VERSION.is_prerelease
and CURRENT_CUDF_VERSION.base_version >= "24.12.0"
)

# TODO: remove when dask min version gets bumped
DASK_SHUFFLE_METHOD_ARG = _dask_version > parseVersion("2024.1.0")
DASK_P2P_ERROR = _dask_version < parseVersion("2023.10.0")
DASK_SHUFFLE_CAST_DTYPE = _dask_version > parseVersion("2023.12.0")
DASK_SHUFFLE_METHOD_ARG = _dask_version > parse_version("2024.1.0")
DASK_P2P_ERROR = _dask_version < parse_version("2023.10.0")
DASK_SHUFFLE_CAST_DTYPE = _dask_version > parse_version("2023.12.0")

# Query-planning check (and cache)
_DASK_QUERY_PLANNING_ENABLED = None
Expand All @@ -36,7 +49,7 @@ def query_planning_enabled():
global _DASK_QUERY_PLANNING_ENABLED

if _DASK_QUERY_PLANNING_ENABLED is None:
if _dask_version > parseVersion("2024.6.0"):
if _dask_version > parse_version("2024.6.0"):
import dask.dataframe as dd

_DASK_QUERY_PLANNING_ENABLED = dd.DASK_EXPR_ENABLED
Expand Down
76 changes: 71 additions & 5 deletions nemo_curator/modules/fuzzy_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from dask.utils import M
from tqdm import tqdm

from nemo_curator._compat import MINHASH_PERMUTED_AVAILABLE
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import FuzzyDuplicatesConfig
Expand Down Expand Up @@ -99,7 +100,14 @@ def __init__(
"""
self.num_hashes = num_hashes
self.char_ngram = char_ngrams
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)
if MINHASH_PERMUTED_AVAILABLE:
self.seeds = self.generate_hash_permutation_seeds(
bit_width=64 if use_64bit_hash else 32,
n_permutations=self.num_hashes,
seed=seed,
)
else:
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)
self.minhash_method = self.minhash64 if use_64bit_hash else self.minhash32
self.id_field = id_field
self.text_field = text_field
Expand Down Expand Up @@ -127,6 +135,35 @@ def generate_seeds(self, n_seeds: int = 260, seed: int = 0) -> np.ndarray:
gen = np.random.RandomState(seed)
return gen.randint(0, 1e6, size=n_seeds)

def generate_hash_permutation_seeds(
self, bit_width: int, n_permutations: int = 260, seed: int = 0
) -> np.ndarray:
"""
Generate seeds for all minhash permutations based on the given seed.
"""
gen = np.random.RandomState(seed)

if bit_width == 32:
MERSENNE_PRIME = np.uint32((1 << 31) - 1)
dtype = np.uint32
elif bit_width == 64:
# For 64-bit, use a larger prime number suitable for 64-bit operations
MERSENNE_PRIME = np.uint64((1 << 61) - 1)
dtype = np.uint64
else:
raise ValueError("Unsupported bit width. Use either 32 or 64.")

return np.array(
[
(
gen.randint(1, MERSENNE_PRIME, dtype=dtype),
gen.randint(0, MERSENNE_PRIME, dtype=dtype),
)
for _ in range(n_permutations)
],
dtype=dtype,
)

def minhash32(
self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
) -> cudf.Series:
Expand All @@ -135,8 +172,23 @@ def minhash32(
"""
if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")
seeds = cudf.Series(seeds, dtype="uint32")
return ser.str.minhash(seeds=seeds, width=char_ngram)

if not MINHASH_PERMUTED_AVAILABLE:
warnings.warn(
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
"or later for improved performance. "
"Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`",
category=FutureWarning,
)
seeds = cudf.Series(seeds, dtype="uint32")
return ser.str.minhash(seeds=seeds, width=char_ngram)
else:
seeds_a = cudf.Series(seeds[:, 0], dtype="uint32")
seeds_b = cudf.Series(seeds[:, 1], dtype="uint32")

return ser.str.minhash_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)

def minhash64(
self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
Expand All @@ -146,8 +198,22 @@ def minhash64(
"""
if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")
seeds = cudf.Series(seeds, dtype="uint64")
return ser.str.minhash64(seeds=seeds, width=char_ngram)
if not MINHASH_PERMUTED_AVAILABLE:
warnings.warn(
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
"or later for improved performance. "
"Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`",
category=FutureWarning,
)
seeds = cudf.Series(seeds, dtype="uint64")
return ser.str.minhash64(seeds=seeds, width=char_ngram)
else:
seeds_a = cudf.Series(seeds[:, 0], dtype="uint64")
seeds_b = cudf.Series(seeds[:, 1], dtype="uint64")

return ser.str.minhash64_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)

def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]:
"""
Expand Down

0 comments on commit 07e2a40

Please sign in to comment.