diff --git a/config/fasttext_langid.yaml b/config/fasttext_langid.yaml
index 86b18761d..a1f4f3530 100644
--- a/config/fasttext_langid.yaml
+++ b/config/fasttext_langid.yaml
@@ -1,5 +1,6 @@
 input_field: text
 filters:
   - name: nemo_curator.filters.classifier_filter.FastTextLangId
+    log_score: True
     params:
       model_path: <Path to the fastText language id model (e.g., lid.176.bin)>
diff --git a/nemo_curator/__init__.py b/nemo_curator/__init__.py
index 000e459a9..4645d55ef 100644
--- a/nemo_curator/__init__.py
+++ b/nemo_curator/__init__.py
@@ -12,4 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import dask
+
 from .modules import *
+
+# Without this option, Dask will automatically convert the
+# list-valued score column to a string, both in memory and
+# when reading from or writing to files.
+# See https://github.com/NVIDIA/NeMo-Curator/issues/33
+dask.config.set({"dataframe.convert-string": False})
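A minimal sketch (not part of the patch) of the behavior this global option guards against; the column name and scores below are illustrative:

```python
# Illustrative only: with "dataframe.convert-string" enabled (the default in recent
# Dask releases), object columns may be coerced to Arrow-backed strings, which
# mangles list-valued scores such as FastTextLangId's [score, language] pairs.
import dask
import dask.dataframe as dd
import pandas as pd

dask.config.set({"dataframe.convert-string": False})

pdf = pd.DataFrame({"lang_id": [[0.98, "EN"], [0.71, "FR"]]})
ddf = dd.from_pandas(pdf, npartitions=1)

print(ddf.dtypes)                    # "lang_id" remains an object column
print(ddf.compute()["lang_id"][0])   # [0.98, 'EN'] rather than the string "[0.98, 'EN']"
```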
diff --git a/nemo_curator/_compat.py b/nemo_curator/_compat.py
index 1dc07d9e0..a89426d52 100644
--- a/nemo_curator/_compat.py
+++ b/nemo_curator/_compat.py
@@ -20,3 +20,4 @@
 # TODO: remove when dask min version gets bumped
 DASK_SHUFFLE_METHOD_ARG = _dask_version > parseVersion("2024.1.0")
 DASK_P2P_ERROR = _dask_version < parseVersion("2023.10.0")
+DASK_SHUFFLE_CAST_DTYPE = _dask_version > parseVersion("2023.12.0")
diff --git a/nemo_curator/filters/classifier_filter.py b/nemo_curator/filters/classifier_filter.py
index 4f06c8b25..741df9640 100644
--- a/nemo_curator/filters/classifier_filter.py
+++ b/nemo_curator/filters/classifier_filter.py
@@ -76,11 +76,6 @@ def __init__(self, model_path=None, min_langid_score=0.3):
         self._cutoff = min_langid_score
         self._name = "lang_id"
 
-        # Dask will automatically convert the list score type
-        # to a string without this option.
-        # See https://github.com/NVIDIA/NeMo-Curator/issues/33
-        dask.config.set({"dataframe.convert-string": False})
-
     @batched
     def score_document(self, df: pd.Series):
         model_attr = f"{self._name}_{self._model_path}"
diff --git a/nemo_curator/modifiers/pii_modifier.py b/nemo_curator/modifiers/pii_modifier.py
index c2a398b48..51ea5b6e2 100644
--- a/nemo_curator/modifiers/pii_modifier.py
+++ b/nemo_curator/modifiers/pii_modifier.py
@@ -17,7 +17,7 @@
 import pandas as pd
 
 from nemo_curator.modifiers import DocumentModifier
-from nemo_curator.pii.algorithm import DEFAULT_LANGUAGE
+from nemo_curator.pii.constants import DEFAULT_LANGUAGE, DEFAULT_MAX_DOC_SIZE
 from nemo_curator.utils.decorators import batched
 from nemo_curator.utils.distributed_utils import load_object_on_worker
 
@@ -97,7 +97,7 @@ def load_deidentifier(self):
 
         if self.device == "gpu":
             spacy.require_gpu()
-        from nemo_curator.pii.algorithm import DEFAULT_MAX_DOC_SIZE, PiiDeidentifier
+        from nemo_curator.pii.algorithm import PiiDeidentifier
 
         deidentifier: PiiDeidentifier = PiiDeidentifier(
             language=self.language,
diff --git a/nemo_curator/pii/algorithm.py b/nemo_curator/pii/algorithm.py
index 762214efb..2b5e16ed0 100644
--- a/nemo_curator/pii/algorithm.py
+++ b/nemo_curator/pii/algorithm.py
@@ -15,6 +15,10 @@
 from pathlib import Path
 from typing import Any, List, Mapping, Union
 
+# NOTE: Importing this module before cluster creation will create a primary CUDA context,
+# which prevents all GPUs from being used when a cluster/client is created later on.
+# Ensure that this module is imported only after cluster creation, and only when the
+# algorithm needs to be executed. See: https://github.com/NVIDIA/NeMo-Curator/issues/64
 import yaml
 from presidio_analyzer import AnalyzerEngine, RecognizerRegistry
 from presidio_analyzer.nlp_engine import NerModelConfiguration
@@ -30,36 +34,16 @@
 from presidio_anonymizer import AnonymizerEngine, BatchAnonymizerEngine
 from presidio_anonymizer.entities import OperatorConfig
 
+from nemo_curator.pii.constants import DEFAULT_LANGUAGE, SUPPORTED_ENTITIES
 from nemo_curator.pii.custom_batch_analyzer_engine import CustomBatchAnalyzerEngine
 from nemo_curator.pii.custom_nlp_engine import CustomNlpEngine
 from nemo_curator.pii.recognizers.address_recognizer import AddressRecognizer
 
 __all__ = [
-    "DEFAULT_LANGUAGE",
-    "SUPPORTED_ENTITIES",
-    "DEFAULT_MAX_DOC_SIZE",
     "PiiDeidentifier",
 ]
 
 
-DEFAULT_LANGUAGE = "en"
-SUPPORTED_ENTITIES = [
-    "ADDRESS",
-    "CREDIT_CARD",
-    "EMAIL_ADDRESS",
-    "DATE_TIME",
-    "IP_ADDRESS",
-    "LOCATION",
-    "PERSON",
-    "URL",
-    "US_SSN",
-    "US_PASSPORT",
-    "US_DRIVER_LICENSE",
-    "PHONE_NUMBER",
-]
-DEFAULT_MAX_DOC_SIZE = 2000000
-
-
 class PiiDeidentifier(object):
     """Cleans PII from an unstructured text"""
 
diff --git a/nemo_curator/pii/constants.py b/nemo_curator/pii/constants.py
new file mode 100644
index 000000000..fc8dcc545
--- /dev/null
+++ b/nemo_curator/pii/constants.py
@@ -0,0 +1,20 @@
+DEFAULT_LANGUAGE = "en"
+
+SUPPORTED_ENTITIES = [
+    "ADDRESS",
+    "CREDIT_CARD",
+    "EMAIL_ADDRESS",
+    "DATE_TIME",
+    "IP_ADDRESS",
+    "LOCATION",
+    "PERSON",
+    "URL",
+    "US_SSN",
+    "US_PASSPORT",
+    "US_DRIVER_LICENSE",
+    "PHONE_NUMBER",
+]
+
+DEFAULT_MAX_DOC_SIZE = 2000000
+
+__all__ = ["DEFAULT_LANGUAGE", "SUPPORTED_ENTITIES", "DEFAULT_MAX_DOC_SIZE"]
diff --git a/nemo_curator/utils/file_utils.py b/nemo_curator/utils/file_utils.py
index af3c2513d..3ec466b4c 100644
--- a/nemo_curator/utils/file_utils.py
+++ b/nemo_curator/utils/file_utils.py
@@ -181,9 +181,8 @@ def parse_str_of_num_bytes(s, return_str=False):
 def _save_jsonl(documents, output_path, start_index=0, max_index=10000, prefix=None):
     """Worker function to write out the data to jsonl files"""
 
-    def _output_json(document):
-        myjson = json.dumps(document, ensure_ascii=False)
-        return myjson.encode("utf-8")
+    def _encode_text(document):
+        return document.strip().encode("utf-8")
 
     def _name(start_index, npad, prefix, i):
         tag = str(start_index + i).rjust(npad, "0")
@@ -195,11 +194,22 @@ def _name(start_index, npad, prefix, i):
 
     output_glob_string = os.path.join(output_path, "*.jsonl")
 
-    documents.map(_output_json).to_textfiles(
+    output_files = documents.map(_encode_text).to_textfiles(
         output_glob_string,
         name_function=name,
     )
 
+    # Delete empty files generated due to empty partitions in the bag
+    for output_file in output_files:
+        try:
+            if os.path.getsize(output_file) == 0:
+                os.remove(output_file)
+        except Exception as exception:
+            print(
+                f"An exception occurred when trying to delete {output_file}.\n{exception}",
+                flush=True,
+            )
+
 
 def reshard_jsonl(
     input_dir, output_dir, output_file_size="100M", start_index=0, file_prefix=""
@@ -212,7 +222,8 @@ def reshard_jsonl(
         output_dir: The output directory where the resharded jsonl files will be written
         output_file_size: Approximate size of output files. Must specify with a string and
             with the unit K, M or G for kilo, mega or gigabytes
-        start_index: Starting index for naming the output files
+        start_index: Starting index for naming the output files. Note: The indices may not
+            be contiguous, since any empty output files produced during sharding are deleted
         file_prefix: Prefix to use to prepend to output file number
     """
 
@@ -222,7 +233,7 @@ def reshard_jsonl(
     input_files = list(get_all_files_paths_under(input_dir))
 
     # Read in the dask bag
-    b = db.read_text(input_files, blocksize=blocksize).map(json.loads)
+    b = db.read_text(input_files, blocksize=blocksize)
 
     # Prepare the output
     output_dir = expand_outdir_and_mkdir(output_dir)
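A hypothetical invocation of `reshard_jsonl` under the updated behavior; the paths and prefix are placeholders, not part of the patch:

```python
from nemo_curator.utils.file_utils import reshard_jsonl

# Reshard a directory of JSONL files into roughly 100 MB shards. Because empty
# output files are now deleted, the resulting file indices may not be contiguous.
reshard_jsonl(
    input_dir="data/raw_jsonl",         # placeholder path
    output_dir="data/resharded_jsonl",  # placeholder path
    output_file_size="100M",            # string with a K, M, or G suffix
    start_index=0,
    file_prefix="shard-",
)
```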
diff --git a/nemo_curator/utils/fuzzy_dedup_utils/merge_utils.py b/nemo_curator/utils/fuzzy_dedup_utils/merge_utils.py
index a144b5602..70bf73004 100644
--- a/nemo_curator/utils/fuzzy_dedup_utils/merge_utils.py
+++ b/nemo_curator/utils/fuzzy_dedup_utils/merge_utils.py
@@ -16,13 +16,14 @@
 from operator import getitem
 
 import numpy as np
+import pandas as pd
 from dask.base import tokenize
 from dask.dataframe.core import new_dd_object
 from dask.dataframe.shuffle import partitioning_index
 from dask.highlevelgraph import HighLevelGraph
 from dask.utils import M
 
-from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import rearange_by_column_direct
+from nemo_curator._compat import DASK_SHUFFLE_CAST_DTYPE
 
 
 def _split_part(part, nsplits):
@@ -129,6 +130,21 @@ def extract_partitioning_index(
     # a partition-wise merge between `left_df` and `right_df`.
     # We call this `global_partitioning_index`:
 
+    if DASK_SHUFFLE_CAST_DTYPE:
+        # Need to use the same type-casting logic as `shuffle`
+        dtypes = {}
+        if not isinstance(merge_on, list):
+            merge_on = [merge_on]
+        for col, dtype in left_df[merge_on].dtypes.items():
+            if pd.api.types.is_numeric_dtype(dtype):
+                dtypes[col] = np.float64
+        if not dtypes:
+            dtypes = None
+        cast_dtype = {"cast_dtype": dtypes}
+    else:
+        # `cast_dtype` argument doesn't exist yet
+        cast_dtype = {}
+
     num_bucket_files = bk_mapping.file_id.max() + 1
     global_partitioning_index = left_df[merge_on].map_partitions(
         partitioning_index,
@@ -137,6 +153,7 @@ def extract_partitioning_index(
         enforce_metadata=False,
         transform_divisions=False,
         align_dataframes=False,
+        **cast_dtype,
     )
 
     if total_bucket_partitions < num_bucket_files:
@@ -157,7 +174,7 @@ def extract_partitioning_index(
     # want to send the rows of `left_df` to the partition
     # indices encoded in `global_partitioning_index`. Instead, we
     # need to take a modulus with `parts_per_bucket_batch` to
-    # define a `"_partitoins"` column.
+    # define a `"_partitions"` column.
     left_df["_partitions"] = global_partitioning_index % parts_per_bucket_batch
 
     return left_df, global_partitioning_index
@@ -195,6 +212,10 @@ def merge_left_to_shuffled_right(
     subset_bucket_df,
     merge_on,
 ):
+    from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import (
+        rearange_by_column_direct,
+    )
+
     # We are merging an unshuffled batch of "left" partitions
     # with a shuffled batch of "right" partitions. To minimize
     # data movement, we can manually rearrange the "left" batch
diff --git a/tests/test_fuzzy_dedup.py b/tests/test_fuzzy_dedup.py
index e89f998e0..1c952d27d 100644
--- a/tests/test_fuzzy_dedup.py
+++ b/tests/test_fuzzy_dedup.py
@@ -16,14 +16,17 @@
 from itertools import combinations
 from typing import Iterable
 
+import dask.dataframe as dd
 import numpy as np
 import pytest
 import yaml
+from dask import config
 from dask.dataframe.utils import assert_eq
 from distributed import Client
 
 from nemo_curator import LSH, FuzzyDuplicates, FuzzyDuplicatesConfig, MinHash
 from nemo_curator.datasets import DocumentDataset
+from nemo_curator.utils.fuzzy_dedup_utils.merge_utils import extract_partitioning_index
 from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from
 
 cudf = gpu_only_import("cudf")
@@ -367,3 +370,74 @@ def test_from_yaml(self, tmpdir):
         config = FuzzyDuplicatesConfig.from_yaml(tmpdir / "config.yaml")
         for param in yaml_params:
             assert getattr(config, param) == yaml_params[param]
+
+
+@pytest.mark.parametrize(
+    "backend",
+    [
+        "pandas",
+        pytest.param(
+            "cudf",
+            marks=pytest.mark.gpu,
+        ),
+    ],
+)
+def test_extract_partitioning_index(backend):
+
+    def add_partition_info(df, partition_info=None):
+        if partition_info is None:
+            df["file_id"] = -1
+        else:
+            df["file_id"] = partition_info["number"]
+        return df
+
+    with config.set({"dataframe.backend": backend}):
+
+        # Create a random `unshuffled` DataFrame with a
+        # "part_id" column to be used as the shuffle index
+        npartitions_left = 7
+        unshuffled = dd.from_dict(
+            {"part_id": np.random.randint(25, size=1000, dtype="int32")},
+            npartitions=npartitions_left,
+        )
+
+        # Create a `bk_mapping` DataFrame that defines
+        # the "correct" mapping beween "part_id" and
+        # the destination partition ("file_id")
+        npartitions_right = 5
+        bk_mapping = (
+            dd.from_dict(
+                {"part_id": np.arange(25, dtype="int32")},
+                npartitions=npartitions_right,
+            )
+            .shuffle("part_id")
+            .map_partitions(add_partition_info)
+            .compute()
+        )
+
+    # Use `extract_partitioning_index` to calculate
+    # the partitioning index and assign it as a new
+    # "_partitions" column
+    result, _ = extract_partitioning_index(
+        unshuffled,
+        "part_id",
+        bk_mapping,
+        npartitions_right,
+        npartitions_right,
+    )
+
+    # Rename the "_partitions" column, shuffle by "part_id",
+    # and then assign a "file_id" column to reflect the final
+    # partition of each row
+    check = (
+        result.rename(columns={"_partitions": "expected_file_id"})
+        .shuffle(
+            "part_id",
+            npartitions=npartitions_right,
+        )
+        .map_partitions(add_partition_info)
+        .compute()
+    )
+
+    # Check that the real and expected partitions match
+    assert (check["file_id"] == check["expected_file_id"]).all()
diff --git a/tests/test_pii_accuracy.py b/tests/test_pii_accuracy.py
index 7e7d58663..850dafd54 100644
--- a/tests/test_pii_accuracy.py
+++ b/tests/test_pii_accuracy.py
@@ -17,7 +17,6 @@
 from pathlib import Path
 
 import pandas as pd
-import pytest
 from dask import dataframe as dd
 from dask.distributed import Client, LocalCluster
 
diff --git a/tutorials/peft-curation/README.md b/tutorials/peft-curation/README.md
new file mode 100644
index 000000000..afa0d66a3
--- /dev/null
+++ b/tutorials/peft-curation/README.md
@@ -0,0 +1,19 @@
+# Curating Datasets for Parameter Efficient Fine-tuning
+
+This tutorial demonstrates the usage of NeMo Curator's Python API to curate a dataset for
+parameter-efficient fine-tuning (PEFT).
+
+In this tutorial, we use the [Enron Emails dataset](https://huggingface.co/datasets/neelblabla/enron_labeled_emails_with_subjects-llama2-7b_finetuning),
+which is a collection of emails, each annotated with a classification label. Each email has
+a subject, a body, and a category (class label). We demonstrate various filtering and processing
+operations that can be applied to each record.
+
+## Usage
+After installing the NeMo Curator package, you can simply run the following command:
+```
+python tutorials/peft-curation/main.py
+```
+
+By default, this tutorial uses at most 8 workers to run the curation pipeline. If you run into
+out-of-memory issues, you can reduce the number of workers by supplying the `--n-workers=N` argument,
+where `N` is the number of workers to spawn.
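For example, a run limited to four workers (the value here is purely illustrative) would be:
```
python tutorials/peft-curation/main.py --n-workers=4
```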
diff --git a/tutorials/peft-curation/docbuilder.py b/tutorials/peft-curation/docbuilder.py
new file mode 100644
index 000000000..3ae0840c9
--- /dev/null
+++ b/tutorials/peft-curation/docbuilder.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+from typing import Dict
+
+import requests
+
+from nemo_curator.download.doc_builder import (
+    DocumentDownloader,
+    DocumentExtractor,
+    DocumentIterator,
+)
+
+
+class EmailsDownloader(DocumentDownloader):
+    def __init__(self, download_dir: str):
+        super().__init__()
+
+        if not os.path.isdir(download_dir):
+            os.makedirs(download_dir)
+
+        self._download_dir = download_dir
+        print("Download directory: ", self._download_dir)
+
+    def download(self, url: str) -> str:
+        filename = os.path.basename(url)
+        output_file = os.path.join(self._download_dir, filename)
+
+        if os.path.exists(output_file):
+            print(f"File '{output_file}' already exists, skipping download.")
+            return output_file
+
+        print(f"Downloading Enron emails dataset from '{url}'...")
+        response = requests.get(url)
+
+        with open(output_file, "wb") as file:
+            file.write(response.content)
+
+        return output_file
+
+
+class EmailsIterator(DocumentIterator):
+
+    def __init__(self):
+        super().__init__()
+        self._counter = -1
+        self._extractor = EmailsExtractor()
+        # The regular expression pattern to extract each email.
+        self._pattern = re.compile(r"\"<s>.*?<s>\"", re.DOTALL)
+
+    def iterate(self, file_path):
+        self._counter = -1
+        file_name = os.path.basename(file_path)
+
+        with open(file_path, "r", encoding="utf-8") as file:
+            lines = file.readlines()
+
+        # Ignore the first line which contains the header.
+        file_content = "".join(lines[1:])
+        # Find all the emails in the file.
+        it = self._pattern.finditer(file_content)
+
+        for email in it:
+            self._counter += 1
+            content = email.group().strip('"').strip()
+            meta = {
+                "filename": file_name,
+                "id": f"email-{self._counter}",
+            }
+            extracted_content = self._extractor.extract(content)
+
+            # Skip if no content extracted
+            if not extracted_content:
+                continue
+
+            record = {**meta, **extracted_content}
+            yield record
+
+
+class EmailsExtractor(DocumentExtractor):
+    def __init__(self):
+        super().__init__()
+        # The regular expression pattern to extract subject/body/label into groups.
+        self._pattern = re.compile(
+            r"Subject:: (.*?)\nBody:: (.*?)\n.*\[/INST\] (.*?) <s>", re.DOTALL
+        )
+
+    def extract(self, content: str) -> Dict[str, str]:
+        matches = self._pattern.findall(content)
+
+        if not matches:
+            return None
+
+        matches = matches[0]
+
+        return {
+            "subject": matches[0].strip(),
+            "body": matches[1].strip(),
+            "category": matches[2].strip(),
+        }
diff --git a/tutorials/peft-curation/filters.py b/tutorials/peft-curation/filters.py
new file mode 100644
index 000000000..0ffcd5be7
--- /dev/null
+++ b/tutorials/peft-curation/filters.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo_curator.filters import DocumentFilter
+
+
+class FilterEmailsWithLongBody(DocumentFilter):
+    """
+    If the email is too long, discard.
+    """
+
+    def __init__(self, max_length: int = 5000):
+        super().__init__()
+        self.max_length = max_length
+
+    def score_document(self, text: str) -> bool:
+        return len(text) <= self.max_length
+
+    def keep_document(self, score) -> bool:
+        return score
+
+
+class FilterEmptyEmails(DocumentFilter):
+    """
+    Detects empty emails (either empty body, or labeled as empty). Returns `True` for empty emails.
+    """
+
+    def score_document(self, text: str) -> bool:
+        return (
+            not isinstance(text, str)  # The text is not a string
+            or len(text.strip()) == 0  # The text is empty
+            or "Empty message" in text  # The email is labeled as empty
+        )
+
+    def keep_document(self, score) -> bool:
+        return score
diff --git a/tutorials/peft-curation/main.py b/tutorials/peft-curation/main.py
new file mode 100644
index 000000000..9210d9f89
--- /dev/null
+++ b/tutorials/peft-curation/main.py
@@ -0,0 +1,179 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import os
+from functools import partial
+from typing import Any
+
+from docbuilder import EmailsDownloader, EmailsIterator
+from filters import FilterEmailsWithLongBody, FilterEmptyEmails
+from modifiers import AddPeriod, AddSystemPrompt
+
+from nemo_curator import ScoreFilter, Sequential
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.modifiers.pii_modifier import PiiModifier
+from nemo_curator.modifiers.unicode_reformatter import UnicodeReformatter
+from nemo_curator.modules.modify import Modify
+from nemo_curator.utils.distributed_utils import get_client
+from nemo_curator.utils.script_utils import add_distributed_args
+
+SCRIPT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
+DATA_DIR = os.path.join(SCRIPT_DIR_PATH, "data")
+DATASET_URL = "https://huggingface.co/datasets/neelblabla/enron_labeled_emails_with_subjects-llama2-7b_finetuning/raw/main/prompts_train.csv"
+
+
+def download_and_convert_to_jsonl() -> str:
+    """
+    Downloads the emails dataset and converts it to JSONL format.
+
+    Returns:
+        str: The path to the JSONL file.
+    """
+
+    # Download the dataset in raw format and convert it to JSONL.
+    downloader = EmailsDownloader(DATA_DIR)
+    output_path = os.path.join(DATA_DIR, "emails.jsonl")
+    raw_fp = downloader.download(DATASET_URL)
+
+    iterator = EmailsIterator()
+
+    # Parse the raw data and write it to a JSONL file.
+    with open(output_path, "w") as f:
+        for record in iterator.iterate(raw_fp):
+            json_record = json.dumps(record, ensure_ascii=False)
+            f.write(json_record + "\n")
+
+    return output_path
+
+
+def redact_pii(dataset: DocumentDataset, text_field) -> DocumentDataset:
+    """
+    Redacts personally identifiable information (PII) from a given dataset.
+
+    Args:
+        dataset (DocumentDataset): The dataset containing documents with PII.
+        text_field (str): The name of the field whose contents should be redacted.
+
+    Returns:
+        DocumentDataset: The redacted dataset with PII replaced by a generic value.
+    """
+    redactor = Modify(
+        PiiModifier(
+            supported_entities=[
+                "ADDRESS",
+                "EMAIL_ADDRESS",
+                "LOCATION",
+                "PERSON",
+                "URL",
+                "PHONE_NUMBER",
+            ],
+            anonymize_action="replace",
+            device="cpu",
+        ),
+        text_field=text_field,
+    )
+    return redactor(dataset)
+
+
+def run_curation_pipeline(args: Any, jsonl_fp: str) -> str:
+    """
+    Run the curation pipeline on the dataset.
+
+    Args:
+        args (Any): Command-line arguments.
+        jsonl_fp (str): The path to the uncurated JSONL file.
+
+    Returns:
+        str: The path to the curated JSONL file.
+    """
+    client = get_client(args, args.device)
+    print(f"    Running the curation pipeline on '{jsonl_fp}'...")
+    orig_dataset = DocumentDataset.read_json(jsonl_fp, add_filename=True)
+    dataset = orig_dataset
+
+    redact_pii_subject = partial(redact_pii, text_field="subject")
+    redact_pii_body = partial(redact_pii, text_field="body")
+
+    curation_steps = Sequential(
+        [
+            #
+            # Unify the text encoding to Unicode.
+            #
+            Modify(UnicodeReformatter(), text_field="subject"),
+            Modify(UnicodeReformatter(), text_field="body"),
+            Modify(UnicodeReformatter(), text_field="category"),
+            #
+            # Filtering
+            #
+            # Filter out empty emails.
+            ScoreFilter(
+                FilterEmptyEmails(), text_field="subject", score_type=bool, invert=True
+            ),
+            ScoreFilter(
+                FilterEmptyEmails(), text_field="body", score_type=bool, invert=True
+            ),
+            ScoreFilter(
+                FilterEmptyEmails(), text_field="category", score_type=bool, invert=True
+            ),
+            # Filter out emails that are too long.
+            ScoreFilter(FilterEmailsWithLongBody(), text_field="body", score_type=bool),
+            #
+            # Redact personally identifiable information (PII).
+            #
+            redact_pii_subject,
+            redact_pii_body,
+            #
+            # Final modifications.
+            #
+            # Add system prompts to every email, which helps the model focus on the task.
+            Modify(AddSystemPrompt(), text_field="body"),
+            # Add a period to the end of each email category, which makes PEFT easier.
+            Modify(AddPeriod(), text_field="category"),
+        ]
+    )
+
+    dataset = curation_steps(dataset)
+    dataset = dataset.persist()
+
+    print(f"    Original dataset length: {len(orig_dataset.df)}")
+    print(f"    After running the curation pipeline: {len(dataset.df)}")
+    print(f"    Writing to '{jsonl_fp}'...")
+    out_path = os.path.join(
+        os.path.dirname(jsonl_fp),
+        "curated",
+    )
+    os.makedirs(out_path, exist_ok=True)
+    dataset.to_json(out_path, write_to_filename=True)
+    client.close()
+    return os.path.join(out_path, os.path.basename(jsonl_fp))
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser = add_distributed_args(parser)
+    args = parser.parse_args()
+    # Limit the total number of workers to ensure we don't run out of memory.
+    args.n_workers = min(args.n_workers, 8)
+
+    # Prepare the download and JSONL directories.
+    if not os.path.isdir(DATA_DIR):
+        os.makedirs(DATA_DIR)
+
+    jsonl_fp = download_and_convert_to_jsonl()
+    run_curation_pipeline(args, jsonl_fp)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tutorials/peft-curation/modifiers.py b/tutorials/peft-curation/modifiers.py
new file mode 100644
index 000000000..059036ee4
--- /dev/null
+++ b/tutorials/peft-curation/modifiers.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo_curator.modifiers import DocumentModifier
+
+# The system prompt template to be inserted into the documents.
+SYS_PROMPT_TEMPLATE = """[INST] <<SYS>> You are reviewing the contents of an email. Based on the content, please categorize this email into one of the following categories:
+1. 'Company Business/Strategy.'
+2. 'Purely Personal.'
+3. 'Personal but in a professional context.'
+4. 'Logistic Arrangements.'
+5. 'Employment arrangements.'
+6. 'Document editing/checking/collaboration.'
+Please provide only one category (e.g., 'Purely Personal.'). <</SYS>>
+
+Content::
+%s
+
+What should this email be categorized as?
+[/INST]
+Answer:: """
+
+
+class AddSystemPrompt(DocumentModifier):
+    """
+    A simple modifier that adds system prompts to each document.
+    """
+
+    def modify_document(self, text: str) -> str:
+        """
+        Inserts system prompts into the document.
+
+        Args:
+            text (str): The text to be modified.
+
+        Returns:
+            str: The modified text.
+        """
+        return SYS_PROMPT_TEMPLATE % text
+
+
+class AddPeriod(DocumentModifier):
+    """
+    A simple modifier that adds a period to the end of each email category.
+    """
+
+    def modify_document(self, text: str) -> str:
+        """
+        Adds a period to the end of each email category.
+
+        Args:
+            text (str): The text to be modified.
+
+        Returns:
+            str: The modified text.
+        """
+        return text + "."
diff --git a/tutorials/tinystories/README.md b/tutorials/tinystories/README.md
index 47074cb3f..45bc3bf33 100644
--- a/tutorials/tinystories/README.md
+++ b/tutorials/tinystories/README.md
@@ -1,6 +1,6 @@
 # TinyStories
 
-This tutorial demonstrates the usage of NeMo Curator's Python API to curate the [TinyStories](https://arxiv.org/abs/2305.07759) dataset. TinyStories is a dataset of short stories generated by GPT-3.5 and GPT-4, featuring words that are undersood by 3 to 4-year olds. The small size of this dataset makes it ideal for creating and validating data curation pipelines on a local machine.
+This tutorial demonstrates the usage of NeMo Curator's Python API to curate the [TinyStories](https://arxiv.org/abs/2305.07759) dataset. TinyStories is a dataset of short stories generated by GPT-3.5 and GPT-4, featuring words that are understood by 3 to 4-year olds. The small size of this dataset makes it ideal for creating and validating data curation pipelines on a local machine.
 
 For simplicity, this tutorial uses the validation split of this dataset, which contains around 22,000 samples.
 
diff --git a/tutorials/tinystories/main.py b/tutorials/tinystories/main.py
index fa4470c35..1fbbba35c 100644
--- a/tutorials/tinystories/main.py
+++ b/tutorials/tinystories/main.py
@@ -97,19 +97,23 @@ def filter_dataset(dataset: DocumentDataset) -> DocumentDataset:
                 WordCountFilter(min_words=80),
                 text_field="text",
                 score_field="word_count",
+                score_type=int,
             ),
-            ScoreFilter(IncompleteStoryFilter(), text_field="text"),
+            ScoreFilter(IncompleteStoryFilter(), text_field="text", score_type=bool),
             ScoreFilter(
                 RepeatingTopNGramsFilter(n=2, max_repeating_ngram_ratio=0.2),
                 text_field="text",
+                score_type=float,
             ),
             ScoreFilter(
                 RepeatingTopNGramsFilter(n=3, max_repeating_ngram_ratio=0.18),
                 text_field="text",
+                score_type=float,
             ),
             ScoreFilter(
                 RepeatingTopNGramsFilter(n=4, max_repeating_ngram_ratio=0.16),
                 text_field="text",
+                score_type=float,
             ),
         ]
     )