Add dataset blending tool (#32)
* Add initial dataset blending function

Signed-off-by: Ryan Wolf <[email protected]>

* Add blend unit tests

Signed-off-by: Ryan Wolf <[email protected]>

* Add self parameter

Signed-off-by: Ryan Wolf <[email protected]>

* Fix return type of blend dataset

Signed-off-by: Ryan Wolf <[email protected]>

* Fix blending tests

Signed-off-by: Ryan Wolf <[email protected]>

* Change assert statement for very uneven blend

Signed-off-by: Ryan Wolf <[email protected]>

* Fix key error

Signed-off-by: Ryan Wolf <[email protected]>

* Add proper proportion blending test

Signed-off-by: Ryan Wolf <[email protected]>

* Add four dataset blend and clarify docs

Signed-off-by: Ryan Wolf <[email protected]>

* Add shuffle module

Signed-off-by: Ryan Wolf <[email protected]>

* Add blend example and tests

Signed-off-by: Ryan Wolf <[email protected]>

* Fix random method name

Signed-off-by: Ryan Wolf <[email protected]>

* Wrap return type in DocumentDataset

Signed-off-by: Ryan Wolf <[email protected]>

* Save result of column drop

Signed-off-by: Ryan Wolf <[email protected]>

* Change equality check for shuffle tests

Signed-off-by: Ryan Wolf <[email protected]>

* Fix expected order after shuffle

Signed-off-by: Ryan Wolf <[email protected]>

* Add more documents to shuffle test

Signed-off-by: Ryan Wolf <[email protected]>

* Add assert statement

Signed-off-by: Ryan Wolf <[email protected]>

* Add within partition shuffle

Signed-off-by: Ryan Wolf <[email protected]>

* Refactor add rand column for shuffle

Signed-off-by: Ryan Wolf <[email protected]>

* Fix filename tests

Signed-off-by: Ryan Wolf <[email protected]>

* Add determinism handling for shuffle

Signed-off-by: Ryan Wolf <[email protected]>

* Change numpy random function

Signed-off-by: Ryan Wolf <[email protected]>

* Fix tests with new random method

Signed-off-by: Ryan Wolf <[email protected]>

* Remove length call from blending

Signed-off-by: Ryan Wolf <[email protected]>

* Improve scaling of blending function

Signed-off-by: Ryan Wolf <[email protected]>

* Fix blend tests

Signed-off-by: Ryan Wolf <[email protected]>

* Add blending script

Signed-off-by: Ryan Wolf <[email protected]>

* Add additional file paths call

Signed-off-by: Ryan Wolf <[email protected]>

* Add documentation

Signed-off-by: Ryan Wolf <[email protected]>

* Reformat docs

Signed-off-by: Ryan Wolf <[email protected]>

* Remove backticks

Signed-off-by: Ryan Wolf <[email protected]>

* Add context manager for shuffle tests

Signed-off-by: Ryan Wolf <[email protected]>

* Add better deterministic shuffle path

Signed-off-by: Ryan Wolf <[email protected]>

* Update documentation and reset index

Signed-off-by: Ryan Wolf <[email protected]>

---------

Signed-off-by: Ryan Wolf <[email protected]>
ryantwolf authored May 3, 2024
1 parent 2bf430c commit f59a799
Showing 9 changed files with 752 additions and 1 deletion.
84 changes: 84 additions & 0 deletions docs/user-guide/DocumentDataset.rst
@@ -137,3 +137,87 @@ In these cases, we recommend processing the input dataset in batches using a sim
This will read in 64 shards at a time, process them, and write them back to disk.
Like ``get_remaining_files``, it only includes files that are in the input directory and not in the output directory.
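For reference, a batching loop of this kind might look like the following. This is a minimal sketch; the helper ``get_batched_files`` and its exact signature are assumptions here, since the relevant part of the file is truncated above.

.. code-block:: python

    from nemo_curator.datasets import DocumentDataset
    from nemo_curator.utils.file_utils import get_batched_files  # assumed helper

    # Process 64 shards at a time, skipping files already present in the output directory
    for files in get_batched_files("input_dir/", "output_dir/", "jsonl", batch_size=64):
        dataset = DocumentDataset.read_json(files, add_filename=True)
        # ... apply curation steps here ...
        dataset.to_json("output_dir/", write_to_filename=True)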

############################
Blending and Shuffling
############################

Blending data from multiple sources can be a great way of improving downstream model performance.
This blending can be done during model training itself (i.e., *online* blending) or it can be done before training (i.e., *offline* blending).
Online blending is useful for rapidly iterating in the training process.
Meanwhile, offline blending is useful if you want to distribute the dataset.
Online blending is currently possible in `NeMo via NVIDIA Megatron Core <https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/datasets/blended_dataset.py>`_, and NeMo Curator offers a way to perform blending offline.

Let's take a look at how datasets can be combined using ``nc.blend_datasets``:

.. code-block:: python

    import nemo_curator as nc
    from nemo_curator.datasets import DocumentDataset

    books = DocumentDataset.read_json("books_dataset/")
    articles = DocumentDataset.read_json("articles_dataset/")
    journals = DocumentDataset.read_json("journals_dataset/")

    datasets = [books, articles, journals]
    target_samples = 1000
    weights = [5.0, 2.0, 1.0]

    blended_dataset = nc.blend_datasets(target_samples, datasets, weights)
    blended_dataset.to_json("blended_dataset/")

* ``datasets = [books, articles, journals]`` Here, we are choosing to blend three different datasets.
These datasets do not have to be in the same file format or be similar in size.
As long as they can be read in as a ``DocumentDataset``, they will work.
The samples from each dataset are always drawn "in order".
The precise order depends on the format.
For sharded jsonl files, entries are drawn starting from the beginning of the file whose name comes first in sorted order.
* ``target_samples = 1000`` This is the desired number of samples in the resulting dataset.
By sample, we mean a document or, more generally, a single datapoint.
There may end up being more samples in the dataset depending on the weights.
* ``weights = [5.0, 2.0, 1.0]`` The relative number of samples that should be taken from each dataset.
Given these weights, the blended dataset will have five times as many samples from books as from journals,
and twice as many samples from articles as from journals.
Weights can be any list of non-negative real numbers.
``nc.blend_datasets`` will normalize the weights and combine the normalized weights with the target samples to determine
how many samples should be taken from each dataset (a short worked example follows this list).
In the case of the books dataset, the following would be the calculation.

.. math::

    \lceil \text{target\_samples} \cdot w_i \rceil = \lceil 1000 \cdot \frac{5}{8} \rceil = 625

If any dataset has fewer samples than its calculated quota, it will be oversampled to meet the quota.
For example, if the books dataset only had 500 documents in it, the first 125 would be repeated to reach
the 625 samples.
* ``blended_dataset = nc.blend_datasets(target_samples, datasets, weights)`` We now call the function itself.
Afterwards, we are left with a blended dataset that we can operate on like any other dataset.
We can apply filters, deduplicate, or classify the documents.
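
For reference, the quota calculation above can be written out in a few lines of plain Python; this mirrors the normalization that ``nc.blend_datasets`` performs internally:

.. code-block:: python

    import math

    target_samples = 1000
    weights = [5.0, 2.0, 1.0]  # books, articles, journals

    # Normalize the weights, then take the ceiling of each dataset's share of target_samples
    weight_sum = sum(weights)
    quotas = [math.ceil(w / weight_sum * target_samples) for w in weights]
    print(quotas)  # [625, 250, 125]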

Because blending datasets involves combining data from multiple sources, the sharding of the original datasets
cannot be preserved. The options ``add_filename=True`` and ``write_to_filename=True`` for reading and writing
datasets are therefore incompatible with ``nc.blend_datasets``.


Shuffling can be another important aspect of dataset management.
NeMo Curator's ``nc.Shuffle`` allows users to reorder all entries in the dataset.

Here is a small example of how this can be done:

.. code-block:: python

    import nemo_curator as nc
    from nemo_curator.datasets import DocumentDataset

    books = DocumentDataset.read_json("books_dataset/")
    shuffle = nc.Shuffle(seed=42)
    shuffled_books = shuffle(books)
    shuffled_books.to_json("shuffled_books/")

* ``shuffle = nc.Shuffle(seed=42)`` This creates a shuffle operation that can be chained with
the various other modules in NeMo Curator (a chaining sketch follows this list). In this example, we fix the seed to be 42.
Setting the seed guarantees determinism, but may be slightly slower (20-30%)
depending on the dataset size.
* ``shuffled_books = shuffle(books)`` The dataset has now been shuffled, and we can save it to the filesystem.
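
Because ``nc.Shuffle`` behaves like the other modules exported from ``nemo_curator`` (see the ``__init__.py`` changes below), it can be placed in a pipeline. A minimal sketch using ``nc.Sequential``; the ``nc.AddId`` arguments are assumptions made for illustration:

.. code-block:: python

    import nemo_curator as nc
    from nemo_curator.datasets import DocumentDataset

    books = DocumentDataset.read_json("books_dataset/")

    # Sequential applies the modules in order: add an id column, then shuffle
    pipeline = nc.Sequential(
        [
            nc.AddId(id_field="id"),  # argument name is an assumption for this sketch
            nc.Shuffle(seed=42),
        ]
    )

    shuffled_books = pipeline(books)
    shuffled_books.to_json("shuffled_books/")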
53 changes: 53 additions & 0 deletions examples/blend_and_shuffle.py
@@ -0,0 +1,53 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import nemo_curator as nc
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import add_distributed_args


def main(args):
# Params
dataset_paths = ["/path/to/first", "/path/to/second", "/path/to/third"]
dataset_weights = [5.0, 2.0, 1.0]
target_size = 1000
output_path = "/path/to/output"

# Set up Dask client
client = get_client(args, args.device)

# Blend the datasets
datasets = [DocumentDataset.read_json(path) for path in dataset_paths]
blended_dataset = nc.blend_datasets(target_size, datasets, dataset_weights)

shuffle = nc.Shuffle(seed=42)
blended_dataset = shuffle(blended_dataset)

# Save the blend
blended_dataset.to_json(output_path)


def attach_args(
parser=argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
),
):
return add_distributed_args(parser)


if __name__ == "__main__":
main(attach_args().parse_args())
2 changes: 1 addition & 1 deletion nemo_curator/datasets/doc_dataset.py
@@ -24,7 +24,7 @@ class DocumentDataset:
Internally it may be distributed across multiple nodes, and may be on GPUs.
"""

def __init__(self, dataset_df):
def __init__(self, dataset_df: dd.DataFrame):
self.df = dataset_df

def __len__(self):
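The new type hint makes explicit that ``DocumentDataset`` wraps a Dask dataframe. A minimal, purely illustrative sketch of constructing one directly:

.. code-block:: python

    import dask.dataframe as dd
    import pandas as pd

    from nemo_curator.datasets import DocumentDataset

    # Wrap a small two-partition Dask dataframe in a DocumentDataset
    pdf = pd.DataFrame({"text": ["doc one", "doc two", "doc three", "doc four"]})
    dataset = DocumentDataset(dd.from_pandas(pdf, npartitions=2))
    print(len(dataset))  # expected to print 4, assuming __len__ delegates to the underlying dataframe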
3 changes: 3 additions & 0 deletions nemo_curator/modules/__init__.py
@@ -22,6 +22,7 @@
from nemo_curator.utils.import_utils import gpu_only_import_from

from .add_id import AddId
from .dataset_ops import blend_datasets, Shuffle
from .exact_dedup import ExactDuplicates
from .filter import Filter, Score, ScoreFilter
from .meta import Sequential
@@ -50,4 +51,6 @@
"Sequential",
"TaskDecontamination",
"AddId",
"blend_datasets",
"Shuffle",
]
183 changes: 183 additions & 0 deletions nemo_curator/modules/dataset_ops.py
@@ -0,0 +1,183 @@
import math
from typing import Any, Callable, List, Optional

import dask.dataframe as dd
import numpy as np

from nemo_curator.datasets.doc_dataset import DocumentDataset


def default_filename(partition_num: int) -> str:
return f"file_{partition_num:010d}.jsonl"


class Shuffle:
def __init__(
self,
seed: Optional[int] = None,
npartitions: Optional[int] = None,
partition_to_filename: Callable[[int], str] = default_filename,
) -> None:
"""
Randomly permutes the dataset. This will make the original "filename" column invalid, so if the column is present it will be overwritten.
Args:
seed: The random seed that will be used to determine which partition (file) each datapoint goes to.
Setting the seed will guarantee determinism, but may be slightly slower (20-30% slower)
depending on the dataset size.
npartitions: The output number of partitions to create in the dataset.
If None, it will retain the same number of partitions as the original dataset.
partition_to_filename: If the filename column is present, it will be overwritten.
Passing a function in through this argument allows the user to configure what the filename
will look like given the partition number. The default method names the partition
f'file_{partition_num:010d}.jsonl' and should be changed if the user is not using a .jsonl format.
"""
self.seed = seed
self.npartitions = npartitions
self.partition_to_filename = partition_to_filename
self.rand_col = "_shuffle_rand"

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
if self.seed is None:
return self.shuffle_nondeterministic(dataset)
else:
return self.shuffle_deterministic(dataset)

def shuffle_deterministic(self, dataset: DocumentDataset) -> DocumentDataset:
new_npartitions = (
dataset.df.npartitions if self.npartitions is None else self.npartitions
)

dataset.df[self.rand_col] = dataset.df.map_partitions(self._add_rand_col)

shuffled_df = dataset.df.set_index(self.rand_col, npartitions=new_npartitions)
shuffled_df = shuffled_df.reset_index(drop=True)

if "filename" in shuffled_df:
shuffled_df["filename"] = shuffled_df.map_partitions(self._add_filename)

return DocumentDataset(shuffled_df)

def shuffle_nondeterministic(self, dataset: DocumentDataset) -> DocumentDataset:
new_npartitions = (
dataset.df.npartitions if self.npartitions is None else self.npartitions
)

dataset.df[self.rand_col] = dataset.df.map_partitions(self._add_rand_col)

shuffled_df = dataset.df.shuffle(
self.rand_col, npartitions=new_npartitions, ignore_index=True
)
shuffled_df = shuffled_df.drop(columns=[self.rand_col])
shuffled_df = shuffled_df.map_partitions(self._partition_shuffle)

return DocumentDataset(shuffled_df)

def _add_rand_col(self, partition, partition_info=None):
if partition_info is None:
partition_info = {
"number": 0,
}

if self.seed is not None:
np.random.seed(self.seed + partition_info["number"])
rand_col = np.random.randint(0, np.iinfo("int64").max, size=len(partition))

return rand_col

def _partition_shuffle(self, partition, partition_info=None):
if partition_info is None:
return partition

partition_num = partition_info["number"]
if self.seed is not None:
random_state = self.seed + partition_num
else:
random_state = None

partition = partition.sample(frac=1, random_state=random_state).reset_index(
drop=True
)

if "filename" in partition:
filename = self.partition_to_filename(partition_num)
partition["filename"] = filename

return partition

def _add_filename(self, partition, partition_info=None):
if partition_info is None:
return ["filename"] * len(partition)

filename = self.partition_to_filename(partition_info["number"])

return [filename for _ in range(len(partition))]


def blend_datasets(
target_size: int, datasets: List[DocumentDataset], sampling_weights: List[float]
) -> DocumentDataset:
"""
Combines multiple datasets into one, with different amounts of each dataset
Args:
target_size: The number of documents the resulting dataset should have.
The actual size of the dataset may be slightly larger if the normalized weights do not allow
for even mixtures of the datasets.
datasets: A list of all datasets to combine together
sampling_weights: A list of weights to assign to each dataset in the input. Weights will be
normalized across the whole list as a part of the sampling process. For example, if the normalized
sampling weight for dataset 1 is 0.02, then 2% of the total samples will be sampled from dataset 1.
There are guaranteed to be math.ceil(normalized_weight_i * target_size) elements from dataset i in
the final blend.
"""
if len(datasets) != len(sampling_weights):
raise ValueError(
f"Different number of datasets and weights specified. {len(datasets)} datasets and {len(sampling_weights)}"
)

weight_sum = sum(sampling_weights)
sampling_weights = [weight / weight_sum for weight in sampling_weights]
num_documents_per_dataset = [
math.ceil(weight * target_size) for weight in sampling_weights
]

blend_components = []
for dataset, num_documents in zip(datasets, num_documents_per_dataset):
# Repeatedly sample from the dataset
while num_documents > 0:
sample = _partition_head(dataset.df, num_documents)
blend_components.append(sample)
num_documents -= len(sample)

blended_dataset = dd.concat(blend_components)

return DocumentDataset(blended_dataset)


def _partition_head(ddf: dd.DataFrame, n: int) -> dd.DataFrame:
"""
Returns the first n rows in a dataframe while preserving the partitions.
Meant as a replacement for ddf.head(npartitions=-1, compute=False) as it
uses too much memory at large scales
Args:
ddf: The dataframe to get the first rows from
n: The number of rows to get
"""
original_meta = ddf.dtypes.to_dict()
partition_lengths = ddf.map_partitions(len)
num_partitions = 0
total_size = 0
last_length = 0
for length in partition_lengths:
total_size += length
num_partitions += 1
last_length = length
if total_size >= n:
break

delayed_df = ddf.to_delayed()
excess_elems = max(0, total_size - n)
delayed_df = delayed_df[:num_partitions]
delayed_df[-1] = delayed_df[-1].head(last_length - excess_elems)

return dd.from_delayed(delayed_df, meta=original_meta)
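
The ``npartitions`` and ``partition_to_filename`` arguments documented in the ``Shuffle`` docstring above let the output layout be controlled. A minimal sketch; the filename pattern and paths are illustrative only:

.. code-block:: python

    import nemo_curator as nc
    from nemo_curator.datasets import DocumentDataset

    # Name shuffled partitions with a custom pattern instead of the default
    # f"file_{partition_num:010d}.jsonl"
    def my_filename(partition_num: int) -> str:
        return f"shard_{partition_num:05d}.jsonl"

    shuffle = nc.Shuffle(
        seed=42,                            # fixed seed takes the deterministic shuffle path
        npartitions=16,                     # repartition the output into 16 files
        partition_to_filename=my_filename,  # overwrites the "filename" column with this pattern
    )

    books = DocumentDataset.read_json("books_dataset/", add_filename=True)
    shuffled = shuffle(books)
    shuffled.to_json("shuffled_books/", write_to_filename=True)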