Retiring text_bytes_aware_shuffle to use shuffle directly (#316)
praateekmahajan authored Nov 20, 2024
1 parent 07e2a40 commit 914e61b
Showing 3 changed files with 12 additions and 117 deletions.
46 changes: 10 additions & 36 deletions nemo_curator/modules/fuzzy_dedup.py
@@ -60,10 +60,7 @@
     build_partition,
     get_agg_text_bytes_df,
 )
-from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import (
-    text_bytes_aware_shuffle,
-    write_partitioned_file,
-)
+from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import write_partitioned_file


class MinHash:
@@ -437,7 +434,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:

         write_path = os.path.join(self.cache_dir, "_buckets.parquet")
         t0 = time.time()
-        with performance_report_if_with_ts_suffix(self.profile_dir, f"lsh-profile"):
+        with performance_report_if_with_ts_suffix(self.profile_dir, "lsh-profile"):
             self.lsh(write_path=write_path, df=df)
         self._logger.info(
             f"Time taken for LSH = {time.time() - t0}s and output written at {write_path}"
@@ -570,7 +567,7 @@ def __call__(self, dataset: DocumentDataset):
         )
         with performance_report_if_with_ts_suffix(
             self.config.profile_dir,
-            f"map_buckets",
+            "map_buckets",
         ):
             ddf_mapped_buckets_w_anchors = (
                 self.map_buckets.map_buckets_with_anchors(
@@ -1060,6 +1057,7 @@ def shuffle_docs_on_buckets(
         bucket_parts_per_worker: int = 8,
         partition_on: str = "_output_partition_id",
     ):
+
         ddf_anchor_docs_with_bk, bk_mapping = aggregated_anchor_docs_with_bk_read(
             path=bucket_w_anchors_path,
             blocksize=bucket_mapping_df_blocksize,
@@ -1206,36 +1204,12 @@ def _batched_merge_and_write(
                 subset_text_df = left_df_use.partitions[
                     text_part_offset:end_text_offset
                 ]
-
-                try:
-                    # NOTE: If we have more text-df partitions than bucket-map
-                    # partitions, we are more likely to see an OverflowError
-
-                    subset_merged_df = merge_left_to_shuffled_right(
-                        subset_text_df,
-                        subset_bucket_df,
-                        merge_on,
-                    )
-                    # Returns a dataframe or None (when the merge is empty)
-                    output_df = text_bytes_aware_shuffle(
-                        df=subset_merged_df,
-                        partition_on=partition_on,
-                        text_column=self.text_field,
-                        num_workers=num_workers,
-                    )
-                except OverflowError as err:
-                    # We encountered an overflow error!
-                    # Let's try again with less text data
-                    parts_per_text_batch_retry = int(parts_per_text_batch_use / 2)
-                    if parts_per_text_batch_retry < 1:
-                        raise err
-                    print(
-                        f"\nWe encountered an OverflowError and will retry "
-                        f"the current batch with {parts_per_text_batch_retry} "
-                        f"text partitions instead of {parts_per_text_batch_use}.",
-                        flush=True,
-                    )
-                    continue
+                subset_merged_df = merge_left_to_shuffled_right(
+                    subset_text_df,
+                    subset_bucket_df,
+                    merge_on,
+                )
+                output_df = subset_merged_df.shuffle(on=partition_on)
 
                 if self.int_to_str_id is not None and output_df is not None:
                     output_df = output_df.map_partitions(
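The new code path replaces the byte-aware shuffle with Dask's plain column shuffle: every row carrying the same partition_on value lands in the same output partition, with no attempt to balance partitions by text size. A minimal sketch of that behavior, using CPU dask.dataframe and made-up column values for illustration (the commit itself operates on dask_cudf frames, which expose the same shuffle(on=...) API):

    # Minimal sketch (not part of the commit): a plain column shuffle routes
    # every row with the same key to the same output partition. CPU
    # dask.dataframe is used here; dask_cudf exposes the same API.
    import pandas as pd
    import dask.dataframe as dd

    pdf = pd.DataFrame({"_output_partition_id": [0, 1, 0, 2, 1], "text": list("abcde")})
    ddf = dd.from_pandas(pdf, npartitions=2)

    shuffled = ddf.shuffle(on="_output_partition_id")
    # Each resulting partition holds only complete key groups.
    print(shuffled.map_partitions(len).compute())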
1 change: 1 addition & 0 deletions nemo_curator/utils/fuzzy_dedup_utils/io_utils.py
@@ -119,6 +119,7 @@ def aggregated_anchor_docs_with_bk_read(path, blocksize):
        sorted(glob(f"{path}/*.parquet"), key=natural_sort_key),
         format="parquet",
     )
+    # create chunks of files whose combined size is less than blocksize
     chunks = chunk_files(ds.get_fragments(), blocksize)
 
     # Record mapping between file indices and partition indices.
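The chunk_files helper documented by the new comment groups parquet fragments greedily until adding another would push a chunk past blocksize. A hypothetical sketch of that greedy grouping, with plain integers standing in for fragment byte sizes (chunk_by_size is illustrative, not the real helper):

    # Hypothetical sketch of greedy size-bounded chunking; chunk_files in
    # io_utils.py is the real implementation and works on parquet fragments.
    def chunk_by_size(sizes, blocksize):
        chunks, current, total = [], [], 0
        for size in sizes:
            # Start a new chunk when the next item would exceed the budget.
            if current and total + size > blocksize:
                chunks.append(current)
                current, total = [], 0
            current.append(size)
            total += size
        if current:
            chunks.append(current)
        return chunks

    print(chunk_by_size([40, 10, 60, 30, 20], blocksize=64))  # [[40, 10], [60], [30, 20]]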
82 changes: 1 addition & 81 deletions nemo_curator/utils/fuzzy_dedup_utils/shuffle_utils.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
 
 import cudf
 import dask_cuda
@@ -21,10 +20,7 @@
 from packaging.version import Version
 
 from nemo_curator._compat import query_planning_enabled
-from nemo_curator.utils.fuzzy_dedup_utils.output_map_utils import (
-    build_partition,
-    get_agg_text_bytes_df,
-)
+from nemo_curator.utils.fuzzy_dedup_utils.output_map_utils import build_partition
 
 dask_cuda_version = Version(dask_cuda.__version__)
 USE_EXCOMMS = (
@@ -126,79 +122,3 @@ def get_shuffle_part_ids_df(
     df[partition_on] = agg_df[partition_on]
     df[output_col] = output_ar
     return df
-
-
-def get_shuffle_partition_info(
-    df,
-    partition_on,
-    output_column,
-    text_column,
-    bytes_column="_text_bytes",
-    num_workers=None,
-):
-    df[bytes_column] = df[text_column].map_partitions(lambda s: s.str.byte_count())
-    agg_df, _ = get_agg_text_bytes_df(
-        df, agg_column=partition_on, bytes_column=bytes_column, n_partitions=1
-    )
-    del df
-
-    agg_df = agg_df.reset_index(drop=True)
-    shuffle_part_ids = agg_df.map_partitions(
-        get_shuffle_part_ids_df,
-        partition_on,
-        size_col=bytes_column,
-        num_workers=num_workers,
-        output_col=output_column,
-    ).persist()
-    return shuffle_part_ids
-
-
-def text_bytes_aware_shuffle(
-    df,
-    partition_on: str,
-    text_column: str,
-    num_workers: Optional[int] = None,
-):
-    """
-    This shuffle takes into account the text bytes of each partition
-    and tries to make sure that the output partitions do not exceed
-    the char limit of cuDF
-
-    Args:
-        df: dask_cudf dataframe
-        partition_on: column name to partition on
-        text_column: column name for the text data
-
-    Returns:
-        dask_cudf dataframe with _partitions columns or None if `df` is empty after the merge
-    """
-    print("Starting text bytes aware shuffle", flush=True)
-    output_col = "_partitions"
-
-    df = df.persist()
-    if len(df) == 0:
-        return None
-    shuffle_part_ids = get_shuffle_partition_info(
-        df=df,
-        partition_on=partition_on,
-        num_workers=num_workers,
-        output_column=output_col,
-        text_column=text_column,
-    )
-    n_output_partitions = shuffle_part_ids[output_col].max().compute() + 1
-    n_output_partitions = int(n_output_partitions)
-    df = df.merge(shuffle_part_ids, on=partition_on, how="inner").persist()
-
-    df = (
-        rearange_by_column_direct(
-            df,
-            col=output_col,
-            npartitions=n_output_partitions,
-            ignore_index=True,
-            excomms_default=True,
-        )
-        .drop(columns=[output_col])
-        .persist()
-    )
-    print(f"Will write {len(df)} rows to disk", flush=True)
-    return df
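For contrast with the retired path: text_bytes_aware_shuffle sized output partitions by accumulated text bytes rather than by key alone, to stay under cuDF's per-column character limit. A toy sketch of that packing idea, assuming a precomputed per-key byte total (assign_partitions is illustrative, not the removed implementation):

    # Toy sketch (not the removed implementation): pack keys into output
    # partitions so that no partition's total text bytes exceed a budget.
    def assign_partitions(byte_counts, budget):
        assignment, part, used = {}, 0, 0
        for key, nbytes in sorted(byte_counts.items(), key=lambda kv: -kv[1]):
            if used and used + nbytes > budget:
                part += 1  # current partition is full; open a new one
                used = 0
            assignment[key] = part
            used += nbytes
        return assignment

    # "a" alone nearly fills partition 0; "b", "c", "d" share partition 1.
    print(assign_partitions({"a": 700, "b": 400, "c": 300, "d": 200}, budget=1000))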