[REVIEW] Speedup Connected Components (NVIDIA#302)
* Speedup fuzzy dedup by avoiding merge (see the merge sketch after this commit message)

Signed-off-by: Vibhu Jawa <[email protected]>

* Remove unused function

Signed-off-by: Vibhu Jawa <[email protected]>

* Clean up PR based on Praateek's reviews

Signed-off-by: Vibhu Jawa <[email protected]>

* style fixes

Signed-off-by: Vibhu Jawa <[email protected]>

* style fixes

Signed-off-by: Vibhu Jawa <[email protected]>

* Remove dangling print

Signed-off-by: Vibhu Jawa <[email protected]>

* Add handling for multiple columns

Signed-off-by: Vibhu Jawa <[email protected]>

* Nuking convert to strings

Signed-off-by: Vibhu Jawa <[email protected]>

* Nuking convert to strings

Signed-off-by: Vibhu Jawa <[email protected]>

* Verify it works on exp-01

Signed-off-by: Vibhu Jawa <[email protected]>

* Add dask profile options and add overwrite

Signed-off-by: Vibhu Jawa <[email protected]>

---------

Signed-off-by: Vibhu Jawa <[email protected]>
Signed-off-by: Vibhu Jawa <[email protected]>
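
The headline change replaces the old _batched_merge_and_write, which looped over batches of pair partitions and merged each batch against the id table, with a single _merge_and_write that indexes the id table once and broadcast-merges it into each side of the pair table. Below is a minimal sketch of that pattern on toy data; the column names mirror the diff (id_x, id_y, jaccard, uid), but the data, partition counts, and standalone-script shape are illustrative assumptions, not code from the PR.

# Minimal sketch of the "avoid batched merges" idea behind _merge_and_write.
# Assumes a GPU environment with cudf/dask_cudf; toy data only.
import cudf
import dask_cudf

# Edge list of similar pairs, keyed by document ids.
pairs = dask_cudf.from_cudf(
    cudf.DataFrame(
        {
            "id_x": ["a", "a", "b", "c"],
            "id_y": ["b", "c", "c", "d"],
            "jaccard": [0.9, 0.85, 0.95, 0.8],
        }
    ),
    npartitions=2,
)

# Small mapping from document id -> dense integer uid.
id_map = dask_cudf.from_cudf(
    cudf.DataFrame({"id": ["a", "b", "c", "d"], "uid": [0, 1, 2, 3]}),
    npartitions=1,
)

# Index the mapping once, then broadcast-merge it into each side of the pair
# table in a single pass instead of looping over partition batches.
id_map = id_map.set_index("id")
for tag in ["x", "y"]:
    pairs = pairs.merge(
        id_map,
        left_on=f"id_{tag}",
        right_index=True,
        how="inner",
        broadcast=True,
    )
    pairs = pairs.drop(columns=f"id_{tag}")
    pairs = pairs.rename(columns={"uid": f"uid_{tag}"})

print(pairs[["uid_x", "uid_y", "jaccard"]].compute())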
VibhuJawa authored Oct 29, 2024
1 parent 1cc6d7a commit 36fcf50
Showing 2 changed files with 43 additions and 88 deletions.
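
For orientation, the stage being sped up takes uid-encoded Jaccard pairs, builds a graph, and labels weakly connected components so that near-duplicate documents share a group id (the vertex/labels frame merged in _run_connected_components below). A rough standalone sketch of that step with toy edges, assuming a dask-cuda cluster with cugraph installed; this is not the repo's exact graph-construction code.

# Rough sketch of the connected-components labeling step (illustrative, not the
# repo's exact code). Assumes a dask-cuda cluster with cugraph installed.
import cudf
import cugraph
import cugraph.dask as dask_cugraph
import dask_cudf
from cugraph.dask.comms import comms as Comms
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

if __name__ == "__main__":
    client = Client(LocalCUDACluster())
    Comms.initialize(p2p=False)  # p2p=False is what the diff switches to

    # uid-encoded pairs (toy data; the real input is the encoded Jaccard pairs).
    edges = dask_cudf.from_cudf(
        cudf.DataFrame({"id_x": [0, 1, 3], "id_y": [1, 2, 4]}),
        npartitions=2,
    )

    G = cugraph.Graph(directed=False)
    G.from_dask_cudf_edgelist(edges, source="id_x", destination="id_y", renumber=True)

    # One row per vertex with a 'labels' column; vertices sharing a label are
    # in the same near-duplicate group.
    labels = dask_cugraph.weakly_connected_components(G)
    print(labels.compute())

    Comms.destroy()
    client.close()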
127 changes: 41 additions & 86 deletions nemo_curator/modules/fuzzy_dedup.py
@@ -469,7 +469,6 @@ def __init__(
cache_dir=self.config.cache_dir,
jaccard_pairs_path=os.path.join(self.config.cache_dir, jaccard_pairs_fname),
id_column=self.config.id_field,
convert_str_ids=False,
jaccard_threshold=self.config.jaccard_threshold,
logger=self._logger,
profile_dir=self.config.profile_dir,
@@ -1405,7 +1404,6 @@ def __init__(
cache_dir: str,
jaccard_pairs_path: str,
id_column="id",
convert_str_ids=False,
jaccard_threshold: float = 0.8,
logger: Union[logging.LoggerAdapter, str] = "./",
profile_dir: Optional[str] = None,
@@ -1415,7 +1413,6 @@ def __init__(
self.id_column = id_column
self.left_id = f"{id_column}_x"
self.right_id = f"{id_column}_y"
self.convert_str_ids = convert_str_ids
self.jaccard_threshold = jaccard_threshold
self.profile_dir = profile_dir
if isinstance(logger, str):
@@ -1451,7 +1448,7 @@ def _run_connected_components(
self.profile_dir, "connected-components-run"
):

Comms.initialize(p2p=True)
Comms.initialize(p2p=False)
df = dask_cudf.read_parquet(
deduped_encoded_jaccard_path, blocksize="1GB", aggregate_files=True
)
@@ -1479,9 +1476,7 @@
labels_df = labels_df.merge(
result, left_on=["uid"], right_on=["vertex"], how="inner"
)
id_columns = (
["dataset_id", "doc_id"] if self.convert_str_ids else [self.id_column]
)
id_columns = [self.id_column]
labels_df = labels_df[id_columns + ["labels"]]
labels_df = labels_df.rename(columns={"labels": "group"})
labels_df = labels_df.persist()
@@ -1498,7 +1493,7 @@ def _run_connected_components(
assert num_nodes == len(labels_df)
# Ensure all docs in the same group are in the same partition
labels_df = labels_df.shuffle(on=["group"], ignore_index=True)
labels_df.to_parquet(output_path, write_index=False)
labels_df.to_parquet(output_path, write_index=False, overwrite=True)
Comms.destroy()
self._logger.info(
f"Time taken for Connected Components Run = {time.time() - t0}s and output written at {output_path}"
@@ -1566,20 +1561,12 @@ def _write_dedup_encoded_jaccard_pair(self, encoded_jaccard_pair_path):
transform_divisions=False,
align_dataframes=False,
)
ddf.to_parquet(output_path, write_index=False)
ddf.to_parquet(output_path, write_index=False, overwrite=True)
self._logger.info(
f"Time taken for Dedup Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}"
)
return output_path

def _convert_str_id_pair_to_int(self, df):
for id, tag in zip([self.left_id, self.right_id], ["x", "y"]):
dx = df[id].str.rsplit("-", n=1, expand=True)
df[f"dataset_id_{tag}"] = dx[0].astype("uint32").values
df[f"doc_id_{tag}"] = dx[1].astype("int64").values
df = df.drop(columns=[id])
return df

def _write_dedup_parsed_id(self):
dedup_parsed_id_path = f"{self.cache_dir}/dedup_parsed_id.parquet"
t0 = time.time()
@@ -1589,22 +1576,10 @@ def _write_dedup_parsed_id(self):
ddf = dask_cudf.read_parquet(
self.jaccard_pairs_path,
columns=[self.left_id, self.right_id],
blocksize="1GB",
blocksize="512MB",
aggregate_files=True,
)
id_columns = [self.id_column]
if self.convert_str_ids:
ddf = ddf.map_partitions(
self._convert_str_id_pair_to_int,
meta={
"dataset_id_x": "uint32",
"doc_id_x": "int64",
"dataset_id_y": "uint32",
"doc_id_y": "int64",
},
)
id_columns = ["dataset_id", "doc_id"]

unique_docs = ddf.map_partitions(
ConnectedComponents._get_unique_ids_per_partition, id_columns=id_columns
)
@@ -1615,7 +1590,9 @@
unique_docs["uid"] = np.uint64(1)
unique_docs["uid"] = unique_docs["uid"].cumsum()
unique_docs["uid"] = unique_docs["uid"] - 1
unique_docs.to_parquet(dedup_parsed_id_path, write_index=False)
unique_docs.to_parquet(
dedup_parsed_id_path, write_index=False, overwrite=True
)
self._logger.info(
f"Time taken for Dedup Parsed Id = {time.time() - t0}s and output written at {dedup_parsed_id_path}"
)
@@ -1630,73 +1607,51 @@ def _write_encoded_jaccard_pair(self, dedup_parsed_id_path):
ddf_id = dask_cudf.read_parquet(
dedup_parsed_id_path, blocksize="2GB", aggregate_files=True
)
ddf_id = ddf_id.persist()
len(ddf_id)
ddf = dask_cudf.read_parquet(
self.jaccard_pairs_path,
blocksize="256MB",
blocksize="1GB",
aggregate_files=True,
)
id_columns = [self.id_column]
if self.convert_str_ids:
ddf = ddf.map_partitions(
self._convert_str_id_pair_to_int,
meta={
"jaccard": "float32",
"dataset_id_x": "uint32",
"doc_id_x": "int64",
"dataset_id_y": "uint32",
"doc_id_y": "int64",
},
)
id_columns = ["dataset_id", "doc_id"]

num_workers = get_num_workers(get_current_client())
self._batched_merge_and_write(
self._merge_and_write(
ddf=ddf,
ddf_id=ddf_id,
output_path=output_path,
id_columns=id_columns,
batch_size=num_workers,
id_column=self.id_column,
)
self._logger.info(
f"Time taken for Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}"
)
return output_path

def _batched_merge_and_write(
self, ddf, ddf_id, output_path, id_columns, batch_size=32
):
total_batches = (ddf.npartitions + batch_size - 1) // batch_size
for batch_id, offset in enumerate(range(0, ddf.npartitions, batch_size)):
st = time.time()
subset_ddf = ddf.partitions[offset : offset + batch_size]
for tag in ["x", "y"]:
pair_ids = []
for id_col in id_columns:
pair_ids.append(f"{id_col}_{tag}")
subset_ddf = subset_ddf.merge(
ddf_id,
left_on=pair_ids,
right_on=id_columns,
how="inner",
broadcast=True,
)
subset_ddf = subset_ddf.drop(
columns=pair_ids,
)
subset_ddf = subset_ddf.rename(
columns={"uid": f"{self.id_column}_{tag}"}
)

subset_ddf = subset_ddf[[self.left_id, self.right_id, "jaccard"]]
output_batch_path = os.path.join(output_path, f"{batch_id}.parquet")
subset_ddf.to_parquet(output_batch_path, write_index=False)

et = time.time()
print(
f"batch_id = {batch_id}/{total_batches}, time = {et - st}", flush=True
def _merge_and_write(
self,
ddf: dask_cudf.DataFrame,
ddf_id: dask_cudf.DataFrame,
output_path: str,
id_column: str,
) -> None:
st = time.time()
# Index the id mapping on id_column so each merge below can join on the index
ddf_id = ddf_id.set_index(id_column)
for tag in ["x", "y"]:
pair_id = f"{id_column}_{tag}"
# Merge 'ddf' with 'ddf_id' to map ids to uids
ddf = ddf.merge(
ddf_id,
left_on=pair_id,
right_index=True,
how="inner",
broadcast=True,
)
ddf = ddf.drop(columns=pair_id)
ddf = ddf.rename(columns={"uid": f"{self.id_column}_{tag}"})
ddf = ddf[[self.left_id, self.right_id, "jaccard"]]
ddf.to_parquet(output_path, write_index=False, overwrite=True)

et = time.time()
self._logger.info(
f"Time taken for merge and write = {et - st}s and output written at {output_path}"
)

@staticmethod
def _get_unique_ids_per_partition(df, id_columns):
@@ -1706,11 +1661,11 @@ def _get_unique_ids_per_partition(df, id_columns):
for id_col in id_columns:
cols_to_drop.append(f"{id_col}_{tag}")

subset_df = df[cols_to_drop].drop_duplicates()
subset_df = df[cols_to_drop].drop_duplicates(ignore_index=True)
subset_df = subset_df.rename(
columns={f"{id_col}_{tag}": f"{id_col}" for id_col in id_columns}
)
unique_df_ls.append(subset_df)
unique_df = cudf.concat(unique_df_ls, ignore_index=True)
unique_df = unique_df.drop_duplicates()
unique_df = unique_df.drop_duplicates(ignore_index=True)
return unique_df
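
The _write_dedup_parsed_id hunk above keeps the existing trick for assigning dense integer ids to documents: collect the unique ids, write a column of ones, take a global cumulative sum, and subtract one. A standalone sketch of that uid assignment, assuming a GPU environment with cudf/dask_cudf and using made-up ids:

# Sketch of the cumsum-based uid assignment used in _write_dedup_parsed_id.
# Toy ids only; assumes a GPU environment with cudf/dask_cudf.
import cudf
import dask_cudf
import numpy as np

docs = dask_cudf.from_cudf(
    cudf.DataFrame({"id": ["d", "a", "c", "b", "a", "d"]}),
    npartitions=3,
)

# One row per unique document id across all partitions.
unique_docs = docs.drop_duplicates()

# A column of ones, cumulatively summed across partitions, numbers the rows
# 1..N in partition order; subtracting one gives a dense 0-based uid.
unique_docs["uid"] = np.uint64(1)
unique_docs["uid"] = unique_docs["uid"].cumsum() - 1

print(unique_docs.compute())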
4 changes: 2 additions & 2 deletions (second changed file: the connected components script)
@@ -32,15 +32,15 @@ def main(args):
st = time.time()
output_path = os.path.join(args.output_dir, "connected_components.parquet")
args.enable_spilling = True

client = get_client(**ArgumentHelper.parse_client_args(args))

components_stage = ConnectedComponents(
cache_dir=args.cache_dir,
jaccard_pairs_path=args.jaccard_pairs_path,
id_column=args.input_json_id_field,
convert_str_ids=True,
jaccard_threshold=args.jaccard_threshold,
logger=args.log_dir,
profile_dir=args.profile_path,
)
components_stage.cc_workflow(output_path=output_path)
print(f"All done in {time.time()-st:.1f} seconds")