Allow users to write to single file (#383)
* jsonl support

Signed-off-by: Sarah Yurick <[email protected]>

* update param name

Signed-off-by: Sarah Yurick <[email protected]>

* run black

Signed-off-by: Sarah Yurick <[email protected]>

* update pytest

Signed-off-by: Sarah Yurick <[email protected]>

* add parquet

Signed-off-by: Sarah Yurick <[email protected]>

* add npartitions check

Signed-off-by: Sarah Yurick <[email protected]>

* add compute and repartition functions

Signed-off-by: Sarah Yurick <[email protected]>

* add runtimeerror

Signed-off-by: Sarah Yurick <[email protected]>

* remove compute function

Signed-off-by: Sarah Yurick <[email protected]>

* address broken pytests

Signed-off-by: Sarah Yurick <[email protected]>

* run black

Signed-off-by: Sarah Yurick <[email protected]>

* add ayush's suggestions

Signed-off-by: Sarah Yurick <[email protected]>

* run isort

Signed-off-by: Sarah Yurick <[email protected]>

---------

Signed-off-by: Sarah Yurick <[email protected]>
Signed-off-by: Sarah Yurick <[email protected]>
sarahyurick authored Dec 13, 2024
1 parent f56e924 commit 079d46f
Showing 16 changed files with 80 additions and 55 deletions.
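
The user-facing effect of this commit is a rename of the writer argument from `output_file_dir` to `output_path`, plus a new code path that lets a single-partition dataset be written to one `.jsonl` file. A minimal before/after sketch of the renamed argument; the paths are illustrative, and loading via `DocumentDataset.read_json` with `add_filename=True` is an assumption, not part of this diff:

```python
from nemo_curator.datasets import DocumentDataset

# Illustrative load; any DocumentDataset works here. add_filename=True keeps a
# "filename" column so write_to_filename=True can be used below.
dataset = DocumentDataset.read_json("/data/input_dir", backend="pandas", add_filename=True)

# Before this commit: dataset.to_json(output_file_dir="/data/output_dir", write_to_filename=True)
# After this commit, the argument is named output_path:
dataset.to_json(output_path="/data/output_dir", write_to_filename=True)
```
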
2 changes: 1 addition & 1 deletion examples/classifiers/aegis_example.py
@@ -44,7 +44,7 @@ def main(args):
)
result_dataset = safety_classifier(dataset=input_dataset)

result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)
result_dataset.to_json(output_path=output_file_path, write_to_filename=True)

global_et = time.time()
print(
2 changes: 1 addition & 1 deletion examples/classifiers/domain_example.py
@@ -39,7 +39,7 @@ def main(args):
domain_classifier = DomainClassifier(filter_by=["Games", "Sports"])
result_dataset = domain_classifier(dataset=input_dataset)

result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)
result_dataset.to_json(output_path=output_file_path, write_to_filename=True)

global_et = time.time()
print(
2 changes: 1 addition & 1 deletion examples/classifiers/fineweb_edu_example.py
@@ -38,7 +38,7 @@ def main(args):

fineweb_edu_classifier = FineWebEduClassifier()
result_dataset = fineweb_edu_classifier(dataset=input_dataset)
result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)
result_dataset.to_json(output_path=output_file_path, write_to_filename=True)

global_et = time.time()
print(
2 changes: 1 addition & 1 deletion examples/classifiers/quality_example.py
@@ -39,7 +39,7 @@ def main(args):
quality_classifier = QualityClassifier(filter_by=["High", "Medium"])
result_dataset = quality_classifier(dataset=input_dataset)

result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)
result_dataset.to_json(output_path=output_file_path, write_to_filename=True)

global_et = time.time()
print(
2 changes: 1 addition & 1 deletion examples/translation_example.py
@@ -368,7 +368,7 @@ def main(args):
)
result_dataset = translator_model(dataset=input_dataset)

result_dataset.to_json(output_file_dir=args.output_data_dir, write_to_filename=True)
result_dataset.to_json(output_path=args.output_data_dir, write_to_filename=True)
print(f"Total time taken for translation: {time.time()-st} seconds", flush=True)
client.close()

15 changes: 10 additions & 5 deletions nemo_curator/datasets/doc_dataset.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
from functools import wraps
from typing import Any, List, Literal, Optional, Union

import dask.dataframe as dd
@@ -37,6 +38,10 @@ def __len__(self) -> int:
def persist(self) -> "DocumentDataset":
return DocumentDataset(self.df.persist())

@wraps(dd.DataFrame.repartition)
def repartition(self, *args, **kwargs) -> "DocumentDataset":
return self.__class__(self.df.repartition(*args, **kwargs))

def head(self, n: int = 5) -> Any:
return self.df.head(n)

@@ -146,7 +151,7 @@ def read_pickle(

def to_json(
self,
output_file_dir: str,
output_path: str,
write_to_filename: bool = False,
keep_filename_column: bool = False,
):
@@ -156,15 +161,15 @@
"""
write_to_disk(
df=self.df,
output_file_dir=output_file_dir,
output_path=output_path,
write_to_filename=write_to_filename,
keep_filename_column=keep_filename_column,
output_type="jsonl",
)

def to_parquet(
self,
output_file_dir: str,
output_path: str,
write_to_filename: bool = False,
keep_filename_column: bool = False,
):
@@ -174,15 +179,15 @@
"""
write_to_disk(
df=self.df,
output_file_dir=output_file_dir,
output_path=output_path,
write_to_filename=write_to_filename,
keep_filename_column=keep_filename_column,
output_type="parquet",
)

def to_pickle(
self,
output_file_dir: str,
output_path: str,
write_to_filename: bool = False,
):
raise NotImplementedError("DocumentDataset does not support to_pickle yet")
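
Alongside the rename, `DocumentDataset` gains a thin `repartition` wrapper around `dask.dataframe.DataFrame.repartition`, which is what makes the new single-file write reachable from the dataset API: collapse to one partition, then point `to_json` at a `.jsonl` file instead of a directory. A sketch under the assumption that the dataset is loaded as in the earlier snippet (paths are hypothetical):

```python
from nemo_curator.datasets import DocumentDataset

# Illustrative load (path is hypothetical).
dataset = DocumentDataset.read_json("/data/input_dir", backend="pandas")

# Collapse to one Dask partition so the whole dataset can land in a single file.
single_partition = dataset.repartition(npartitions=1)

# With exactly one partition, output_path may be a single .jsonl file instead of
# a directory; multi-partition datasets hit the RuntimeError added in write_to_disk.
single_partition.to_json(output_path="/data/results.jsonl")
```
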
2 changes: 1 addition & 1 deletion nemo_curator/sample_dataframe.py
@@ -76,7 +76,7 @@ def sample_dataframe(df, num_samples):
sampled_df = sample_dataframe(df, num_samples=args.num_samples)
write_to_disk(
df=sampled_df,
output_file_dir=args.output_file_path,
output_path=args.output_file_path,
write_to_filename=True,
)
et = time.time()
@@ -86,7 +86,7 @@ def main():

write_to_disk(
df=df,
output_file_dir=args.output_data_dir,
output_path=args.output_data_dir,
write_to_filename=add_filename,
output_type=args.output_file_type,
)
@@ -86,7 +86,7 @@ def main():

write_to_disk(
df=df,
output_file_dir=args.output_data_dir,
output_path=args.output_data_dir,
write_to_filename=add_filename,
output_type=args.output_file_type,
)
@@ -87,7 +87,7 @@ def main():

write_to_disk(
df=df,
output_file_dir=args.output_data_dir,
output_path=args.output_data_dir,
write_to_filename=add_filename,
output_type=args.output_file_type,
)
@@ -87,7 +87,7 @@ def main():

write_to_disk(
df=df,
output_file_dir=args.output_data_dir,
output_path=args.output_data_dir,
write_to_filename=add_filename,
output_type=args.output_file_type,
)
80 changes: 50 additions & 30 deletions nemo_curator/utils/distributed_utils.py
@@ -641,7 +641,7 @@ def _single_partition_write_to_simple_bitext(


def _merge_tmp_simple_bitext_partitions(tmp_output_dir: str, output_dir: str):
"""Merge partitions of simple bitext files in `tmp_output_dir` into files at `output_file_dir`.
"""Merge partitions of simple bitext files in `tmp_output_dir` into files at `output_dir`.
Args:
tmp_output_dir (str): temporary directory that has all the simple bitext output partitions,
@@ -675,7 +675,7 @@ def _merge_tmp_simple_bitext_partitions(tmp_output_dir: str, output_dir: str):

def write_to_disk(
df,
output_file_dir: str,
output_path: str,
write_to_filename: bool = False,
keep_filename_column: bool = False,
output_type: str = "jsonl",
@@ -687,15 +687,30 @@
Args:
df: A Dask DataFrame.
output_file_dir: The output file path.
output_path: The output file path.
write_to_filename: Boolean representing whether to write the filename using the "filename" column.
keep_filename_column: Boolean representing whether to keep or drop the "filename" column, if it exists.
output_type: The type of output file to write. Can be "jsonl" or "parquet".
"""
if write_to_filename and "filename" not in df.columns:

# output_path is a file name
if isinstance(output_path, str) and output_path.endswith(".jsonl"):
if df.npartitions == 1:
df.map_partitions(
_write_to_jsonl_or_parquet, output_path, output_type
).compute()
return
else:
raise RuntimeError(
"Could not write multi-partition DataFrame to a single JSONL file. "
"Please specify a directory output path or repartition the DataFrame."
)

# output_path is a directory
elif write_to_filename and "filename" not in df.columns:
raise ValueError(
"write_using_filename is True but no filename column found in df"
"write_using_filename is True but no filename column found in DataFrame"
)

if is_cudf_type(df):
@@ -705,42 +720,33 @@
else:
output_meta = pd.Series([True], dtype="bool")

# output_path is a directory
if write_to_filename and output_type != "bitext":
os.makedirs(output_file_dir, exist_ok=True)
os.makedirs(output_path, exist_ok=True)
output = df.map_partitions(
single_partition_write_with_filename,
output_file_dir,
output_path,
keep_filename_column=keep_filename_column,
output_type=output_type,
meta=output_meta,
enforce_metadata=False,
)
output = output.compute()

# output_path is a directory
else:
if output_type == "jsonl":
if is_cudf_type(df):
# See open issue here: https://github.com/rapidsai/cudf/issues/15211
# df.to_json(output_file_dir, orient="records", lines=True, engine="cudf", force_ascii=False)
df.to_json(
output_file_dir, orient="records", lines=True, force_ascii=False
)
else:
df.to_json(
output_file_dir, orient="records", lines=True, force_ascii=False
)
elif output_type == "parquet":
df.to_parquet(output_file_dir, write_index=False)
if output_type == "jsonl" or output_type == "parquet":
_write_to_jsonl_or_parquet(df, output_path, output_type)
elif output_type == "bitext":
if write_to_filename:
os.makedirs(output_file_dir, exist_ok=True)
tmp_output_file_dir = os.path.join(output_file_dir, ".tmp")
os.makedirs(output_path, exist_ok=True)
tmp_output_file_dir = os.path.join(output_path, ".tmp")
os.makedirs(tmp_output_file_dir, exist_ok=True)
file_name = os.path.basename(list(df.filename.unique())[0])
else:
tmp_output_file_dir = os.path.join(output_file_dir, ".tmp")
tmp_output_file_dir = os.path.join(output_path, ".tmp")
os.makedirs(tmp_output_file_dir, exist_ok=True)
file_name = os.path.basename(output_file_dir)
file_name = os.path.basename(output_path)

output = df.map_partitions(
_single_partition_write_to_simple_bitext,
Expand All @@ -751,17 +757,31 @@ def write_to_disk(
output = output.compute()
_merge_tmp_simple_bitext_partitions(
tmp_output_file_dir,
(
output_file_dir
if write_to_filename
else os.path.dirname(output_file_dir)
),
(output_path if write_to_filename else os.path.dirname(output_path)),
)
shutil.rmtree(tmp_output_file_dir)
else:
raise ValueError(f"Unknown output type: {output_type}")

print(f"Writing to disk complete for {df.npartitions} partitions", flush=True)
print(f"Writing to disk complete for {df.npartitions} partition(s)", flush=True)


def _write_to_jsonl_or_parquet(
df,
output_path: str,
output_type: Literal["jsonl", "parquet"] = "jsonl",
):
if output_type == "jsonl":
if is_cudf_type(df):
# See open issue here: https://github.com/rapidsai/cudf/issues/15211
# df.to_json(output_path, orient="records", lines=True, engine="cudf", force_ascii=False)
df.to_json(output_path, orient="records", lines=True, force_ascii=False)
else:
df.to_json(output_path, orient="records", lines=True, force_ascii=False)
elif output_type == "parquet":
df.to_parquet(output_path, write_index=False)
else:
raise ValueError(f"Unknown output type: {output_type}")


def load_object_on_worker(attr, load_object_function, load_object_kwargs):
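
The core of the change lives in `write_to_disk`: when `output_path` ends in `.jsonl`, a single-partition DataFrame is written directly through the new `_write_to_jsonl_or_parquet` helper, while a multi-partition DataFrame raises `RuntimeError`. A hedged sketch of how a caller might handle that branch; the DataFrame and paths are illustrative, not taken from this commit's tests:

```python
import dask.dataframe as dd
import pandas as pd

from nemo_curator.utils.distributed_utils import write_to_disk

# Illustrative two-partition DataFrame.
df = dd.from_pandas(pd.DataFrame({"text": ["a", "b", "c", "d"]}), npartitions=2)

try:
    # Single-file target + multiple partitions raises the new RuntimeError.
    write_to_disk(df=df, output_path="/data/out.jsonl", output_type="jsonl")
except RuntimeError:
    # Either repartition down to one partition before writing the single file...
    write_to_disk(
        df=df.repartition(npartitions=1),
        output_path="/data/out.jsonl",
        output_type="jsonl",
    )
    # ...or pass a directory as output_path to keep one file per partition.
```
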
2 changes: 1 addition & 1 deletion tests/test_io.py
@@ -233,7 +233,7 @@ def test_multifile_multi_partition(self, tmp_path, file_ext, read_f):
ddf["filename"] = ddf["filename"] + f".{file_ext}"
write_to_disk(
df=ddf,
output_file_dir=tmp_path / file_ext,
output_path=tmp_path / file_ext,
write_to_filename=True,
output_type=file_ext,
)
2 changes: 1 addition & 1 deletion tests/test_separate_by_metadata.py
@@ -30,7 +30,7 @@ def _write_data(num_files, file_ext):
df = dd.concat(dfs)
write_to_disk(
df=df,
output_file_dir=str(out_path),
output_path=str(out_path),
write_to_filename=True,
output_type=file_ext,
)
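
The updated tests keep exercising the filename-based write path, which is unchanged apart from the argument rename: with `write_to_filename=True`, each row's `filename` column decides which file under `output_path` it lands in, and the reworded `ValueError` fires if that column is missing. A small sketch with hypothetical data and paths:

```python
import dask.dataframe as dd
import pandas as pd

from nemo_curator.utils.distributed_utils import write_to_disk

# Rows carry a "filename" column naming the output file each row belongs to.
pdf = pd.DataFrame(
    {"text": ["a", "b", "c"], "filename": ["part0.jsonl", "part0.jsonl", "part1.jsonl"]}
)
df = dd.from_pandas(pdf, npartitions=1)

# Writes /data/by_filename/part0.jsonl and /data/by_filename/part1.jsonl;
# dropping the "filename" column first would raise the ValueError shown above.
write_to_disk(
    df=df,
    output_path="/data/by_filename",
    write_to_filename=True,
    output_type="jsonl",
)
```
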
@@ -144,7 +144,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -175,7 +175,7 @@
"%%time\n",
"\n",
"result_dataset = classifier(dataset=input_dataset)\n",
"result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=write_to_filename)"
"result_dataset.to_json(output_path=output_file_path, write_to_filename=write_to_filename)"
]
},
{
@@ -632,7 +632,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -662,7 +662,7 @@
"id_dataset = add_id(dataset)\n",
"\n",
"# Save the dataset with added IDs to disk\n",
"write_to_disk(id_dataset.df, output_file_dir=added_id_output_path, write_to_filename=True, output_type=\"parquet\")"
"write_to_disk(id_dataset.df, output_path=added_id_output_path, write_to_filename=True, output_type=\"parquet\")"
]
},
{
@@ -848,7 +848,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -879,7 +879,7 @@
"]\n",
"\n",
"# Save the final deduplicated dataset\n",
"write_to_disk(result, output_file_dir=deduped_output_dir, write_to_filename=True, output_type=\"parquet\")"
"write_to_disk(result, output_path=deduped_output_dir, write_to_filename=True, output_type=\"parquet\")"
]
},
{
@@ -1196,7 +1196,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -1222,7 +1222,7 @@
"filtered_dataset = filter_pipeline(target_dataset)\n",
"\n",
"# Save the filtered dataset\n",
"write_to_disk(filtered_dataset.df, output_file_dir=CF_output_path, write_to_filename=True, output_type=\"parquet\")"
"write_to_disk(filtered_dataset.df, output_path=CF_output_path, write_to_filename=True, output_type=\"parquet\")"
]
},
{
