From 36fcf50cee12ccd3e85b204f7ef8c4f62c84aa51 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 29 Oct 2024 16:28:03 -0700 Subject: [PATCH] [REVIEW] Speedup Connected Components (#302) * Speedup fuzzy dedup by avoiding merge Signed-off-by: Vibhu Jawa * Remove unused function Signed-off-by: Vibhu Jawa * Clean up PR based on Praateeks reviews Signed-off-by: Vibhu Jawa * style fixes Signed-off-by: Vibhu Jawa * style fixes Signed-off-by: Vibhu Jawa * Remove dangling print Signed-off-by: Vibhu Jawa * Add handling for multiple columns Signed-off-by: Vibhu Jawa * Nuking convert to strings Signed-off-by: Vibhu Jawa * Nuking convert to strings Signed-off-by: Vibhu Jawa * Verify it works on exp-01 Signed-off-by: Vibhu Jawa * Add dask profile options and add overwrite Signed-off-by: Vibhu Jawa --------- Signed-off-by: Vibhu Jawa Signed-off-by: Vibhu Jawa --- nemo_curator/modules/fuzzy_dedup.py | 127 ++++++------------ .../connected_components.py | 4 +- 2 files changed, 43 insertions(+), 88 deletions(-) diff --git a/nemo_curator/modules/fuzzy_dedup.py b/nemo_curator/modules/fuzzy_dedup.py index 63576516c..f9c80ce71 100644 --- a/nemo_curator/modules/fuzzy_dedup.py +++ b/nemo_curator/modules/fuzzy_dedup.py @@ -469,7 +469,6 @@ def __init__( cache_dir=self.config.cache_dir, jaccard_pairs_path=os.path.join(self.config.cache_dir, jaccard_pairs_fname), id_column=self.config.id_field, - convert_str_ids=False, jaccard_threshold=self.config.jaccard_threshold, logger=self._logger, profile_dir=self.config.profile_dir, @@ -1405,7 +1404,6 @@ def __init__( cache_dir: str, jaccard_pairs_path: str, id_column="id", - convert_str_ids=False, jaccard_threshold: float = 0.8, logger: Union[logging.LoggerAdapter, str] = "./", profile_dir: Optional[str] = None, @@ -1415,7 +1413,6 @@ def __init__( self.id_column = id_column self.left_id = f"{id_column}_x" self.right_id = f"{id_column}_y" - self.convert_str_ids = convert_str_ids self.jaccard_threshold = jaccard_threshold self.profile_dir = profile_dir if isinstance(logger, str): @@ -1451,7 +1448,7 @@ def _run_connected_components( self.profile_dir, "connected-components-run" ): - Comms.initialize(p2p=True) + Comms.initialize(p2p=False) df = dask_cudf.read_parquet( deduped_encoded_jaccard_path, blocksize="1GB", aggregate_files=True ) @@ -1479,9 +1476,7 @@ def _run_connected_components( labels_df = labels_df.merge( result, left_on=["uid"], right_on=["vertex"], how="inner" ) - id_columns = ( - ["dataset_id", "doc_id"] if self.convert_str_ids else [self.id_column] - ) + id_columns = [self.id_column] labels_df = labels_df[id_columns + ["labels"]] labels_df = labels_df.rename(columns={"labels": "group"}) labels_df = labels_df.persist() @@ -1498,7 +1493,7 @@ def _run_connected_components( assert num_nodes == len(labels_df) # Ensure all docs in the same group are in the same partition labels_df = labels_df.shuffle(on=["group"], ignore_index=True) - labels_df.to_parquet(output_path, write_index=False) + labels_df.to_parquet(output_path, write_index=False, overwrite=True) Comms.destroy() self._logger.info( f"Time taken for Connected Components Run = {time.time() - t0}s and output written at {output_path}" @@ -1566,20 +1561,12 @@ def _write_dedup_encoded_jaccard_pair(self, encoded_jaccard_pair_path): transform_divisions=False, align_dataframes=False, ) - ddf.to_parquet(output_path, write_index=False) + ddf.to_parquet(output_path, write_index=False, overwrite=True) self._logger.info( f"Time taken for Dedup Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}" ) return output_path - def _convert_str_id_pair_to_int(self, df): - for id, tag in zip([self.left_id, self.right_id], ["x", "y"]): - dx = df[id].str.rsplit("-", n=1, expand=True) - df[f"dataset_id_{tag}"] = dx[0].astype("uint32").values - df[f"doc_id_{tag}"] = dx[1].astype("int64").values - df = df.drop(columns=[id]) - return df - def _write_dedup_parsed_id(self): dedup_parsed_id_path = f"{self.cache_dir}/dedup_parsed_id.parquet" t0 = time.time() @@ -1589,22 +1576,10 @@ def _write_dedup_parsed_id(self): ddf = dask_cudf.read_parquet( self.jaccard_pairs_path, columns=[self.left_id, self.right_id], - blocksize="1GB", + blocksize="512MB", aggregate_files=True, ) id_columns = [self.id_column] - if self.convert_str_ids: - ddf = ddf.map_partitions( - self._convert_str_id_pair_to_int, - meta={ - "dataset_id_x": "uint32", - "doc_id_x": "int64", - "dataset_id_y": "uint32", - "doc_id_y": "int64", - }, - ) - id_columns = ["dataset_id", "doc_id"] - unique_docs = ddf.map_partitions( ConnectedComponents._get_unique_ids_per_partition, id_columns=id_columns ) @@ -1615,7 +1590,9 @@ def _write_dedup_parsed_id(self): unique_docs["uid"] = np.uint64(1) unique_docs["uid"] = unique_docs["uid"].cumsum() unique_docs["uid"] = unique_docs["uid"] - 1 - unique_docs.to_parquet(dedup_parsed_id_path, write_index=False) + unique_docs.to_parquet( + dedup_parsed_id_path, write_index=False, overwrite=True + ) self._logger.info( f"Time taken for Dedup Parsed Id = {time.time() - t0}s and output written at {dedup_parsed_id_path}" ) @@ -1630,73 +1607,51 @@ def _write_encoded_jaccard_pair(self, dedup_parsed_id_path): ddf_id = dask_cudf.read_parquet( dedup_parsed_id_path, blocksize="2GB", aggregate_files=True ) - ddf_id = ddf_id.persist() - len(ddf_id) ddf = dask_cudf.read_parquet( self.jaccard_pairs_path, - blocksize="256MB", + blocksize="1GB", aggregate_files=True, ) - id_columns = [self.id_column] - if self.convert_str_ids: - ddf = ddf.map_partitions( - self._convert_str_id_pair_to_int, - meta={ - "jaccard": "float32", - "dataset_id_x": "uint32", - "doc_id_x": "int64", - "dataset_id_y": "uint32", - "doc_id_y": "int64", - }, - ) - id_columns = ["dataset_id", "doc_id"] - - num_workers = get_num_workers(get_current_client()) - self._batched_merge_and_write( + self._merge_and_write( ddf=ddf, ddf_id=ddf_id, output_path=output_path, - id_columns=id_columns, - batch_size=num_workers, + id_column=self.id_column, ) self._logger.info( f"Time taken for Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}" ) return output_path - def _batched_merge_and_write( - self, ddf, ddf_id, output_path, id_columns, batch_size=32 - ): - total_batches = (ddf.npartitions + batch_size - 1) // batch_size - for batch_id, offset in enumerate(range(0, ddf.npartitions, batch_size)): - st = time.time() - subset_ddf = ddf.partitions[offset : offset + batch_size] - for tag in ["x", "y"]: - pair_ids = [] - for id_col in id_columns: - pair_ids.append(f"{id_col}_{tag}") - subset_ddf = subset_ddf.merge( - ddf_id, - left_on=pair_ids, - right_on=id_columns, - how="inner", - broadcast=True, - ) - subset_ddf = subset_ddf.drop( - columns=pair_ids, - ) - subset_ddf = subset_ddf.rename( - columns={"uid": f"{self.id_column}_{tag}"} - ) - - subset_ddf = subset_ddf[[self.left_id, self.right_id, "jaccard"]] - output_batch_path = os.path.join(output_path, f"{batch_id}.parquet") - subset_ddf.to_parquet(output_batch_path, write_index=False) - - et = time.time() - print( - f"batch_id = {batch_id}/{total_batches}, time = {et - st}", flush=True + def _merge_and_write( + self, + ddf: dask_cudf.DataFrame, + ddf_id: dask_cudf.DataFrame, + output_path: str, + id_column: str, + ) -> None: + st = time.time() + # Ensure 'id_columns' is a list + ddf_id = ddf_id.set_index(id_column) + for tag in ["x", "y"]: + pair_id = f"{id_column}_{tag}" + # Merge 'ddf' with 'ddf_id' to map ids to uids + ddf = ddf.merge( + ddf_id, + left_on=pair_id, + right_index=True, + how="inner", + broadcast=True, ) + ddf = ddf.drop(columns=pair_id) + ddf = ddf.rename(columns={"uid": f"{self.id_column}_{tag}"}) + ddf = ddf[[self.left_id, self.right_id, "jaccard"]] + ddf.to_parquet(output_path, write_index=False, overwrite=True) + + et = time.time() + self._logger.info( + f"Time taken for merge and write = {et - st}s and output written at {output_path}" + ) @staticmethod def _get_unique_ids_per_partition(df, id_columns): @@ -1706,11 +1661,11 @@ def _get_unique_ids_per_partition(df, id_columns): for id_col in id_columns: cols_to_drop.append(f"{id_col}_{tag}") - subset_df = df[cols_to_drop].drop_duplicates() + subset_df = df[cols_to_drop].drop_duplicates(ignore_index=True) subset_df = subset_df.rename( columns={f"{id_col}_{tag}": f"{id_col}" for id_col in id_columns} ) unique_df_ls.append(subset_df) unique_df = cudf.concat(unique_df_ls, ignore_index=True) - unique_df = unique_df.drop_duplicates() + unique_df = unique_df.drop_duplicates(ignore_index=True) return unique_df diff --git a/nemo_curator/scripts/fuzzy_deduplication/connected_components.py b/nemo_curator/scripts/fuzzy_deduplication/connected_components.py index e6353b786..33e37f105 100644 --- a/nemo_curator/scripts/fuzzy_deduplication/connected_components.py +++ b/nemo_curator/scripts/fuzzy_deduplication/connected_components.py @@ -32,15 +32,15 @@ def main(args): st = time.time() output_path = os.path.join(args.output_dir, "connected_components.parquet") args.enable_spilling = True - client = get_client(**ArgumentHelper.parse_client_args(args)) components_stage = ConnectedComponents( cache_dir=args.cache_dir, jaccard_pairs_path=args.jaccard_pairs_path, id_column=args.input_json_id_field, - convert_str_ids=True, jaccard_threshold=args.jaccard_threshold, + logger=args.log_dir, + profile_dir=args.profile_path, ) components_stage.cc_workflow(output_path=output_path) print(f"All done in {time.time()-st:.1f} seconds")