Add dataset blending tool #32

Merged
merged 38 commits from rywolf/blending on May 3, 2024
Commits (38)
4a183cf
Add initial dataset blending function
ryantwolf Apr 2, 2024
969123c
Add blend unit tests
ryantwolf Apr 4, 2024
d8e7922
Add self parameter
ryantwolf Apr 4, 2024
fd5f339
Fix return type of blend dataset
ryantwolf Apr 4, 2024
1e1401b
Fix blending tests
ryantwolf Apr 4, 2024
5d966ba
Change assert statement for very uneven blend
ryantwolf Apr 4, 2024
64d5040
Fix key error
ryantwolf Apr 4, 2024
7972669
Add proper proportion blending test
ryantwolf Apr 4, 2024
f603c2c
Add four dataset blend and clarify docs
ryantwolf Apr 4, 2024
e53c1eb
Merge branch 'main' into rywolf/blending
ryantwolf Apr 22, 2024
203f88a
Add shuffle module
ryantwolf Apr 22, 2024
5899432
Add blend example and tests
ryantwolf Apr 22, 2024
f197610
Fix random method name
ryantwolf Apr 22, 2024
2bc5a5f
Wrap return type in DocumentDataset
ryantwolf Apr 22, 2024
bb9803d
Save result of column drop
ryantwolf Apr 22, 2024
414091d
Change equality check for shuffle tests
ryantwolf Apr 22, 2024
88ae8af
Fix expected order after shuffle
ryantwolf Apr 22, 2024
b016672
Add more documents to shuffle test
ryantwolf Apr 22, 2024
48bb22b
Add assert statement
ryantwolf Apr 22, 2024
cd9fdcf
Add within partition shuffle
ryantwolf Apr 22, 2024
491f8b6
Refactor add rand column for shuffle
ryantwolf Apr 23, 2024
89009a8
Fix filename tests
ryantwolf Apr 23, 2024
aaa6493
Add determinism handling for shuffle
ryantwolf Apr 23, 2024
22376f2
Merge branch 'main' into rywolf/blending
ryantwolf Apr 24, 2024
b1e1a42
Change numpy random function
ryantwolf Apr 24, 2024
c5f7b3e
Fix tests with new random method
ryantwolf Apr 24, 2024
5dda670
Remove length call from blending
ryantwolf Apr 26, 2024
6cee1fc
Improve scaling of blending function
ryantwolf Apr 29, 2024
2a1c7cb
Fix blend tests
ryantwolf Apr 29, 2024
7025d88
Add blending script
ryantwolf Apr 30, 2024
b7d3f50
Add additional file paths call
ryantwolf Apr 30, 2024
e7c2a9c
Add documentation
ryantwolf Apr 30, 2024
6ddff54
Reformat docs
ryantwolf Apr 30, 2024
ccc6e0c
Remove backticks
ryantwolf Apr 30, 2024
0d81ace
Merge branch 'main' into rywolf/blending
ryantwolf Apr 30, 2024
53fcda9
Add context manager for shuffle tests
ryantwolf Apr 30, 2024
2c578cb
Add better deterministic shuffle path
ryantwolf May 3, 2024
fa25b33
Update documentation and reset index
ryantwolf May 3, 2024
85 changes: 85 additions & 0 deletions docs/user-guide/DocumentDataset.rst
@@ -137,3 +137,88 @@ In these cases, we recommend processing the input dataset in batches using a sim

This will read in 64 shards at a time, process them, and write them back to disk.
Like ``get_remaining_files``, it only includes files that are in the input directory and not in the output directory.

############################
Blending and Shuffling
############################

Blending data from multiple sources can be a great way of improving downstream model performance.
This blending can be done during model training itself (i.e., *online* blending) or it can be done before training (i.e., *offline* blending).
Online blending is useful for rapidly iterating in the training process.
Meanwhile, offline blending is useful if you want to distribute the dataset.
Online blending is currently possible in NeMo, and NeMo Curator offers a way to perform blending offline.

Let's take a look at how datasets can be combined using ``nc.blend_datasets``:

.. code-block:: python

    import nemo_curator as nc
    from nemo_curator.datasets import DocumentDataset

    books = DocumentDataset.read_json("books_dataset/")
    articles = DocumentDataset.read_json("articles_dataset/")
    journals = DocumentDataset.read_json("journals_dataset/")

    datasets = [books, articles, journals]
    target_samples = 1000
    weights = [5.0, 2.0, 1.0]

    blended_dataset = nc.blend_datasets(target_samples, datasets, weights)

    blended_dataset.to_json("blended_dataset/")


* ``datasets = [books, articles, journals]`` Here, we are choosing to blend three different datasets.
  These datasets do not have to be in the same file format or of similar size.
  As long as they can be read in as a ``DocumentDataset``, they can be blended.
  The samples from each dataset are always drawn "in order".
  The precise order depends on the format.
  For sharded jsonl files, samples are drawn starting from the file whose name comes first in sorted order.
* ``target_samples = 1000`` This is the desired number of samples in the resulting dataset.
  By "sample" we mean a document, or more generally a single datapoint.
  The final dataset may contain slightly more samples than this, depending on the weights.
* ``weights = [5.0, 2.0, 1.0]`` The relative number of samples that should be taken from each dataset.
  Given these weights, the blended dataset will have five times as many samples from books as from journals,
  and twice as many samples from articles as from journals.
  Weights can be a list of non-negative real numbers.
  ``nc.blend_datasets`` normalizes the weights and combines them with the target samples to determine
  how many samples should be taken from each dataset.
  For the books dataset, the calculation would be:

  .. math::

     \lceil target\_samples \cdot w_i\rceil=\lceil 1000\cdot \frac{5}{8}\rceil=625

  If a dataset has fewer samples than its calculated share, it will be oversampled to meet the quota.
  For example, if the books dataset only had 500 documents in it, the first 125 would be repeated to reach
  the 625 samples. A small sketch of this arithmetic appears after this list.
* ``blended_dataset = nc.blend_datasets(target_samples, datasets, weights)`` We now call the function itself.
  Afterwards, we are left with a blended dataset that we can operate on like any other dataset.
  We can apply filters, deduplicate, or classify the documents.
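
To make the arithmetic concrete, below is a small sketch of the per-dataset sample counts implied by the
weights above. It mirrors the normalization that ``nc.blend_datasets`` performs internally; the variable
names are only illustrative.

.. code-block:: python

    import math

    target_samples = 1000
    weights = [5.0, 2.0, 1.0]  # books, articles, journals

    # Normalize the weights, then take the ceiling of each dataset's share.
    weight_sum = sum(weights)
    samples_per_dataset = [
        math.ceil(w / weight_sum * target_samples) for w in weights
    ]

    print(samples_per_dataset)  # [625, 250, 125]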

Because blending datasets involves combining data from multiple sources, the sharding of the original datasets
cannot be preserved. The options ``add_filename=True`` and ``write_to_filename=True`` for reading and writing
datasets are therefore incompatible with ``nc.blend_datasets``.


Shuffling can be another important aspect of dataset management.
NeMo Curator's ``nc.Shuffle`` allows users to reorder all entries in the dataset.

Here is a small example of how this can be done:

.. code-block:: python

    import nemo_curator as nc
    from nemo_curator.datasets import DocumentDataset

    books = DocumentDataset.read_json("books_dataset/")

    shuffle = nc.Shuffle(seed=42)

    shuffled_books = shuffle(books)

    shuffled_books.to_json("shuffled_books/")

* ``shuffle = nc.Shuffle(seed=42)`` This creates a shuffle operation that can be chained with
  the various other modules in NeMo Curator. In this example, we fix the seed to be 42.
  Even with a fixed seed, the shuffle is only guaranteed to be deterministic when run with a single-threaded,
  single-process client. Dask allows us to perform the shuffle in parallel for speed gains, so using a
  multiprocessing client is recommended whenever determinism is optional (see the sketch below).
* ``shuffled_books = shuffle(books)`` The dataset has now been shuffled, and we can save it to the filesystem.
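
If a fully deterministic shuffle is required, one option (a minimal sketch, not part of this PR) is to start
Dask with a single worker and a single thread before applying ``nc.Shuffle``. The ``dask.distributed.Client``
arguments below are standard Dask options; everything else reuses the API shown above.

.. code-block:: python

    from dask.distributed import Client

    import nemo_curator as nc
    from nemo_curator.datasets import DocumentDataset

    # One worker with one thread keeps the shuffle deterministic,
    # at the cost of giving up parallel speedups.
    client = Client(n_workers=1, threads_per_worker=1)

    books = DocumentDataset.read_json("books_dataset/")
    shuffled_books = nc.Shuffle(seed=42)(books)
    shuffled_books.to_json("shuffled_books/")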
53 changes: 53 additions & 0 deletions examples/blend_and_shuffle.py
@@ -0,0 +1,53 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import nemo_curator as nc
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import add_distributed_args


def main(args):
    # Params
    dataset_paths = ["/path/to/first", "/path/to/second", "/path/to/third"]
    dataset_weights = [5.0, 2.0, 1.0]
    target_size = 1000
    output_path = "/path/to/output"

    # Set up Dask client
    client = get_client(args, args.device)

    # Blend the datasets
    datasets = [DocumentDataset.read_json(path) for path in dataset_paths]
    blended_dataset = nc.blend_datasets(target_size, datasets, dataset_weights)

    shuffle = nc.Shuffle(seed=42)
    blended_dataset = shuffle(blended_dataset)

    # Save the blend
    blended_dataset.to_json(output_path)


def attach_args(
    parser=argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    ),
):
    return add_distributed_args(parser)


if __name__ == "__main__":
    main(attach_args().parse_args())
2 changes: 1 addition & 1 deletion nemo_curator/datasets/doc_dataset.py
@@ -24,7 +24,7 @@ class DocumentDataset:
    Internally it may be distributed across multiple nodes, and may be on GPUs.
    """

    def __init__(self, dataset_df):
    def __init__(self, dataset_df: dd.DataFrame):
        self.df = dataset_df

    def __len__(self):
3 changes: 3 additions & 0 deletions nemo_curator/modules/__init__.py
@@ -22,6 +22,7 @@
from nemo_curator.utils.import_utils import gpu_only_import_from

from .add_id import AddId
from .dataset_ops import blend_datasets, Shuffle
from .exact_dedup import ExactDuplicates
from .filter import Filter, Score, ScoreFilter
from .meta import Sequential
@@ -50,4 +51,6 @@
"Sequential",
"TaskDecontamination",
"AddId",
"blend_datasets",
"Shuffle",
]
150 changes: 150 additions & 0 deletions nemo_curator/modules/dataset_ops.py
@@ -0,0 +1,150 @@
import math
from typing import Any, Callable, List, Optional

import dask.dataframe as dd
import numpy as np

from nemo_curator.datasets.doc_dataset import DocumentDataset


def default_filename(partition_num: int) -> str:
return f"file_{partition_num:010d}.jsonl"


class Shuffle:
    def __init__(
        self,
        seed: Optional[int] = None,
        npartitions: Optional[int] = None,
        partition_to_filename: Callable[[int], str] = default_filename,
    ) -> None:
        """
        Randomly permutes the dataset. This will make the original "filename" column invalid, so if the column is present it will be overwritten.

        Args:
            seed: The random seed that will be used to determine which partition (file) each datapoint goes to.
                Even with the seed set, the shuffle is not guaranteed to be deterministic if done in parallel.
                Initialize the Dask client with a single worker and a single thread in order to ensure determinism.
            npartitions: The output number of partitions to create in the dataset.
                If None, it will retain the same number of partitions as the original dataset.
            partition_to_filename: If the filename column is present, it will be overwritten.
                Passing a function in through this argument allows the user to configure what the filename
                will look like given the partition number. The default method names the partition
                f'file_{partition_num:010d}.jsonl' and should be changed if the user is not using a .jsonl format.
        """
        self.seed = seed
        self.npartitions = npartitions
        self.partition_to_filename = partition_to_filename
        self.rand_col = "_shuffle_rand"

    def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
        new_npartitions = (
            dataset.df.npartitions if self.npartitions is None else self.npartitions
        )

        rand_df = dataset.df.map_partitions(self._add_rand_col, new_npartitions)

        shuffled_df = rand_df.shuffle(
            self.rand_col, npartitions=new_npartitions, ignore_index=True
        )
        shuffled_df = shuffled_df.drop(columns=[self.rand_col])
        shuffled_df = shuffled_df.map_partitions(self._partition_shuffle)

        return DocumentDataset(shuffled_df)

    def _add_rand_col(self, partition, new_npartitions, partition_info=None):
        if partition_info is None:
            partition_info = {
                "number": 0,
            }

        np.random.seed(self.seed + partition_info["number"])
        partition[self.rand_col] = np.random.randint(
            0, new_npartitions, size=len(partition)
        )

        return partition

    def _partition_shuffle(self, partition, partition_info=None):
        if partition_info is None:
            return partition

        partition_num = partition_info["number"]
        partition = partition.sample(
            frac=1, random_state=self.seed + partition_num
        ).reset_index(drop=True)

        if "filename" in partition:
            filename = self.partition_to_filename(partition_num)
            partition["filename"] = filename

        return partition


def blend_datasets(
    target_size: int, datasets: List[DocumentDataset], sampling_weights: List[float]
) -> DocumentDataset:
    """
    Combines multiple datasets into one with different amounts of each dataset.

    Args:
        target_size: The number of documents the resulting dataset should have.
            The actual size of the dataset may be slightly larger if the normalized weights do not allow
            for even mixtures of the datasets.
        datasets: A list of all datasets to combine together
        sampling_weights: A list of weights to assign to each dataset in the input. Weights will be
            normalized across the whole list as a part of the sampling process. For example, if the normalized
            sampling weight for dataset 1 is 0.02, 2% of the total samples will be sampled from dataset 1.
            There are guaranteed to be math.ceil(normalized_weight_i * target_size) elements from dataset i in
            the final blend.
    """
    if len(datasets) != len(sampling_weights):
        raise ValueError(
            f"Different number of datasets and weights specified. {len(datasets)} datasets and {len(sampling_weights)}"
        )

    weight_sum = sum(sampling_weights)
    sampling_weights = [weight / weight_sum for weight in sampling_weights]
    num_documents_per_dataset = [
        math.ceil(weight * target_size) for weight in sampling_weights
    ]

    blend_components = []
    for dataset, num_documents in zip(datasets, num_documents_per_dataset):
        # Repeatedly sample from the dataset
        while num_documents > 0:
            sample = _partition_head(dataset.df, num_documents)
            blend_components.append(sample)
            num_documents -= len(sample)

    blended_dataset = dd.concat(blend_components)

    return DocumentDataset(blended_dataset)


def _partition_head(ddf: dd.DataFrame, n: int) -> dd.DataFrame:
    """
    Returns the first n rows in a dataframe while preserving the partitions.
    Meant as a replacement for ddf.head(npartitions=-1, compute=False) as it
    uses too much memory at large scales.

    Args:
        ddf: The dataframe to get the first rows from
        n: The number of rows to get
    """
    original_meta = ddf.dtypes.to_dict()
    partition_lengths = ddf.map_partitions(len)
    num_partitions = 0
    total_size = 0
    last_length = 0
    for length in partition_lengths:
        total_size += length
        num_partitions += 1
        last_length = length
        if total_size >= n:
            break

    delayed_df = ddf.to_delayed()
    excess_elems = max(0, total_size - n)
    delayed_df = delayed_df[:num_partitions]
    delayed_df[-1] = delayed_df[-1].head(last_length - excess_elems)

    return dd.from_delayed(delayed_df, meta=original_meta)