Commit 5899432
Add blend example and tests
Signed-off-by: Ryan Wolf <[email protected]>
ryantwolf committed Apr 22, 2024
1 parent 203f88a commit 5899432
Showing 2 changed files with 119 additions and 0 deletions.
53 changes: 53 additions & 0 deletions examples/blend_and_shuffle.py
@@ -0,0 +1,53 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import nemo_curator as nc
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import add_distributed_args


def main(args):
    # Params
    dataset_paths = ["/path/to/first", "/path/to/second", "/path/to/third"]
    dataset_weights = [5.0, 2.0, 1.0]
    target_size = 1000
    output_path = "/path/to/output"

    # Set up Dask client
    client = get_client(args, args.device)

    # Blend the datasets
    datasets = [DocumentDataset.read_json(path) for path in dataset_paths]
    blended_dataset = nc.blend_datasets(target_size, datasets, dataset_weights)

    shuffle = nc.Shuffle(seed=42)
    blended_dataset = shuffle(blended_dataset)

    # Save the blend
    blended_dataset.to_json(output_path)


def attach_args(
    parser=argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    ),
):
    return add_distributed_args(parser)


if __name__ == "__main__":
    main(attach_args().parse_args())
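
For context on what the example relies on: nc.blend_datasets combines the input datasets so that each contributes documents roughly in proportion to its weight until about target_size documents are collected. The snippet below is a minimal pandas-only sketch of that idea under that assumption; it is an illustration, not the library's implementation, which operates on Dask-backed DocumentDataset partitions.

import pandas as pd


def blend_sketch(target_size, dataframes, weights):
    # Illustration only: give each dataframe a share of rows proportional to
    # its weight, repeating a dataframe when it has fewer rows than its share.
    total_weight = sum(weights)
    parts = []
    for df, weight in zip(dataframes, weights):
        share = int(target_size * weight / total_weight)
        if share == 0:
            continue
        repeats = -(-share // len(df))  # ceiling division
        parts.append(pd.concat([df] * repeats).head(share))
    return pd.concat(parts).reset_index(drop=True)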
66 changes: 66 additions & 0 deletions tests/test_shuffle.py
@@ -0,0 +1,66 @@
import dask.dataframe as dd
import pandas as pd
from dask.dataframe.utils import assert_eq

import nemo_curator as nc
from nemo_curator.datasets import DocumentDataset


def list_to_dataset(documents, col_name="text", npartitions=2):
    data = {col_name: documents}
    pdf = pd.DataFrame(data)

    return DocumentDataset(dd.from_pandas(pdf, npartitions=npartitions))


class TestShuffling:
    def test_shuffle(self):
        original_dataset = list_to_dataset(["one", "two", "three"])
        expected_dataset = list_to_dataset(["one", "two", "three"])
        shuffle = nc.Shuffle(seed=42)
        result_dataset = shuffle(original_dataset)
        assert_eq(expected_dataset.df, result_dataset.df)

    def test_new_partitions(self):
        original_dataset = list_to_dataset(["one", "two", "three"], npartitions=3)
        expected_dataset = list_to_dataset(["one", "two", "three"])
        shuffle = nc.Shuffle(seed=42, npartitions=2)
        result_dataset = shuffle(original_dataset)
        assert_eq(expected_dataset.df, result_dataset.df)

    def test_filename(self):
        original_dataset = list_to_dataset(["one", "two", "three"], npartitions=1)
        original_dataset.df["filename"] = "original.jsonl"

        expected_data = {
            "text": ["one", "two", "three"],
            "filename": [
                "file_0000000001.jsonl",
                "file_0000000001.jsonl",
                "file_0000000002.jsonl",
            ],
        }
        pdf = pd.DataFrame(expected_data)
        expected_dataset = DocumentDataset(dd.from_pandas(pdf, npartitions=2))

        shuffle = nc.Shuffle(seed=42, npartitions=2)
        result_dataset = shuffle(original_dataset)
        assert_eq(expected_dataset.df, result_dataset.df)

    def test_custom_filenames(self):
        original_dataset = list_to_dataset(["one", "two", "three"], npartitions=1)
        original_dataset.df["filename"] = "original.jsonl"

        expected_data = {
            "text": ["one", "two", "three"],
            "filename": ["my_1.test", "my_1.test", "my_2.test"],
        }
        pdf = pd.DataFrame(expected_data)
        expected_dataset = DocumentDataset(dd.from_pandas(pdf, npartitions=2))

        def filename_fn(x):
            return f"my_{x}.test"

        shuffle = nc.Shuffle(seed=42, npartitions=2, partition_to_filename=filename_fn)
        result_dataset = shuffle(original_dataset)
        assert_eq(expected_dataset.df, result_dataset.df)
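
Usage note (inferred from the tests above, not from separate documentation): partition_to_filename is a callable that maps a partition index to the filename assigned to every document in that partition, while the default naming produces names like file_0000000001.jsonl, as test_filename expects. The suite follows the repository's pytest layout and can be run with pytest tests/test_shuffle.py.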
