From 9f8578b49bcd7b2f57dfc9cb6adf38fffdfc75b8 Mon Sep 17 00:00:00 2001
From: Vibhu Jawa <vibhujawa@gmail.com>
Date: Tue, 21 May 2024 15:18:44 -0700
Subject: [PATCH] [REVIEW] Switch Models to use Crossfit (#58)

Replace the hand-rolled per-partition PyTorch inference stack with
crossfit-based tokenize/predict/label pipelines. The
nemo_curator.distributed_data_classification package (arg_utils,
pytorch_utils, and the standalone domain/quality inference scripts) is
removed; DomainClassifier and QualityClassifier now build on crossfit's
HFModel and take a model_path argument in place of model_file_name. The
CLI entry points move to nemo_curator/scripts/, the examples and docs are
renamed and updated accordingly, and a distributed data classification
tutorial notebook is added.
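
The post-patch user-facing API looks roughly like the sketch below, based
on the updated docs excerpt in this patch (import paths follow the modules
touched here; the label list is abbreviated for illustration):

    from nemo_curator.datasets import DocumentDataset
    from nemo_curator.modules.distributed_data_classifier import DomainClassifier
    from nemo_curator.utils.file_utils import get_all_files_paths_under

    # Abbreviated label set; the full 27-label domain taxonomy is listed
    # in the domain classifier code touched by this patch.
    labels = ["Games", "Sports", "Travel_and_Transportation"]

    files = get_all_files_paths_under("books_dataset/")
    input_dataset = DocumentDataset.read_json(files, backend="cudf", add_filename=True)

    domain_classifier = DomainClassifier(
        model_path="pytorch_model_file.pth",  # renamed from model_file_name
        labels=labels,
        filter_by=["Games", "Sports"],
    )
    result_dataset = domain_classifier(dataset=input_dataset)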
---
 .../DistributedDataClassification.rst         |   4 +-
 ...xample.py => domain_classifier_example.py} |   4 +-
 ...ample.py => quality_classifier_example.py} |   6 +-
 .../__init__.py                               |  13 -
 .../arg_utils.py                              | 163 -------
 .../domain_classifier_inference.py            | 275 ------------
 .../generate_statistics.py                    |  82 ----
 .../pytorch_utils.py                          | 105 -----
 .../quality_classifier_inference.py           | 314 -------------
 ...ty_classifier_multiple_models_inference.py | 155 -------
 .../modules/distributed_data_classifier.py    | 425 ++++++++++--------
 .../scripts/domain_classifier_inference.py    | 132 ++++++
 .../scripts/quality_classifier_inference.py   | 134 ++++++
 .../verify_classification_results.py}         |   0
 nemo_curator/utils/script_utils.py            |  76 ++++
 setup.py                                      |  10 +-
 .../distributed_data_classification.ipynb     | 376 ++++++++++++++++
 17 files changed, 964 insertions(+), 1310 deletions(-)
 rename examples/{distributed_data_classification_examples/domain_api_example.py => domain_classifier_example.py} (97%)
 rename examples/{distributed_data_classification_examples/quality_api_example.py => quality_classifier_example.py} (96%)
 delete mode 100644 nemo_curator/distributed_data_classification/__init__.py
 delete mode 100644 nemo_curator/distributed_data_classification/arg_utils.py
 delete mode 100644 nemo_curator/distributed_data_classification/domain_classifier_inference.py
 delete mode 100644 nemo_curator/distributed_data_classification/generate_statistics.py
 delete mode 100644 nemo_curator/distributed_data_classification/pytorch_utils.py
 delete mode 100644 nemo_curator/distributed_data_classification/quality_classifier_inference.py
 delete mode 100644 nemo_curator/distributed_data_classification/quality_classifier_multiple_models_inference.py
 create mode 100644 nemo_curator/scripts/domain_classifier_inference.py
 create mode 100644 nemo_curator/scripts/quality_classifier_inference.py
 rename nemo_curator/{distributed_data_classification/verify_results.py => scripts/verify_classification_results.py} (100%)
 create mode 100644 tutorials/distributed_data_classification/distributed_data_classification.ipynb
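
At the core of the change, the per-partition PyTorch DataLoader loop is
replaced by a crossfit pipeline. A condensed sketch of the new flow,
mirroring _run_classifier_helper in
nemo_curator/modules/distributed_data_classifier.py (df, model, labels,
max_chars, and batch_size are assumed to be in scope as in that helper):

    from crossfit import op

    df["sliced_text"] = df["text"].str.slice(0, max_chars)
    keep_cols = [c for c in df.columns.to_list() if c != "sliced_text"]

    # Tokenize the truncated text and run batched inference with a sorted
    # data loader; probabilities land in the temporary "_prob" column.
    classifier_pipe = op.Sequential(
        op.Tokenizer(model, cols=["sliced_text"], tokenizer_type="sentencepiece"),
        op.Predictor(
            model,
            sorted_data_loader=True,
            batch_size=batch_size,
            pred_output_col="_prob",
        ),
        repartition=df.npartitions,
        keep_cols=keep_cols,
    )
    df = classifier_pipe(df)

    # Map per-class probabilities to label strings in a second pass so the
    # probability column survives (see the TODO in the helper).
    labeling_pipe = op.Sequential(
        op.Labeler(labels, cols=["_prob"]),
        keep_cols=keep_cols + ["_prob"],
    )
    df = labeling_pipe(df)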

diff --git a/docs/user-guide/DistributedDataClassification.rst b/docs/user-guide/DistributedDataClassification.rst
index f2bf098d3..5e69cf143 100644
--- a/docs/user-guide/DistributedDataClassification.rst
+++ b/docs/user-guide/DistributedDataClassification.rst
@@ -49,13 +49,13 @@ Let's see how ``DomainClassifier`` works in a small excerpt taken from ``example
         "Travel_and_Transportation",
     ]
 
-    model_file_name = "pytorch_model_file.pth"
+    model_path = "pytorch_model_file.pth"
 
     files = get_all_files_paths_under("books_dataset/")
     input_dataset = DocumentDataset.read_json(files, backend="cudf", add_filename=True)
 
     domain_classifier = DomainClassifier(
-        model_file_name=model_file_name,
+        model_path=model_path,
         labels=labels,
         filter_by=["Games", "Sports"],
     )
diff --git a/examples/distributed_data_classification_examples/domain_api_example.py b/examples/domain_classifier_example.py
similarity index 97%
rename from examples/distributed_data_classification_examples/domain_api_example.py
rename to examples/domain_classifier_example.py
index ad2fa1c8b..83ac4e63d 100644
--- a/examples/distributed_data_classification_examples/domain_api_example.py
+++ b/examples/domain_classifier_example.py
@@ -53,7 +53,7 @@ def main(args):
         "Travel_and_Transportation",
     ]
 
-    model_file_name = "/path/to/pytorch_model_file.pth"
+    model_path = "/path/to/pytorch_model_file.pth"
 
     # Input can be a string or list
     input_file_path = "/path/to/data"
@@ -66,7 +66,7 @@ def main(args):
     )
 
     domain_classifier = DomainClassifier(
-        model_file_name=model_file_name,
+        model_path=model_path,
         labels=labels,
         filter_by=["Games", "Sports"],
     )
diff --git a/examples/distributed_data_classification_examples/quality_api_example.py b/examples/quality_classifier_example.py
similarity index 96%
rename from examples/distributed_data_classification_examples/quality_api_example.py
rename to examples/quality_classifier_example.py
index 53b9849c4..4a137f682 100644
--- a/examples/distributed_data_classification_examples/quality_api_example.py
+++ b/examples/quality_classifier_example.py
@@ -25,7 +25,7 @@ def main(args):
     global_st = time.time()
 
     labels = ["High", "Medium", "Low"]
-    model_file_name = "/path/to/pytorch_model_file.pth"
+    model_path = "/path/to/pytorch_model_file.pth"
 
     # Input can be a string or list
     input_file_path = "/path/to/data"
@@ -33,12 +33,12 @@ def main(args):
 
     client = get_client(args, cluster_type=args.device)
 
-    input_dataset = DocumentDataset.from_json(
+    input_dataset = DocumentDataset.read_json(
         input_file_path, backend="cudf", add_filename=True
     )
 
     quality_classifier = QualityClassifier(
-        model_file_name=model_file_name,
+        model_path=model_path,
         labels=labels,
         filter_by=["High", "Medium"],
     )
diff --git a/nemo_curator/distributed_data_classification/__init__.py b/nemo_curator/distributed_data_classification/__init__.py
deleted file mode 100644
index d9155f923..000000000
--- a/nemo_curator/distributed_data_classification/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/nemo_curator/distributed_data_classification/arg_utils.py b/nemo_curator/distributed_data_classification/arg_utils.py
deleted file mode 100644
index 967d1c421..000000000
--- a/nemo_curator/distributed_data_classification/arg_utils.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-os.environ["RAPIDS_NO_INITIALIZE"] = "1"
-import warnings
-
-warnings.filterwarnings("ignore")
-
-
-def add_input_output_args(parser):
-    """
-    This function adds the command line arguments related to input and output files.
-
-    Args:
-        parser: An argparse ArgumentParser object.
-    Returns:
-        An argparse ArgumentParser with 3 additional arguments.
-
-    """
-    parser.add_argument(
-        "--input_file_path",
-        type=str,
-        help="The path of the input files",
-        required=True,
-    )
-    parser.add_argument(
-        "--input_file_type",
-        type=str,
-        help="The type of the input files",
-        required=True,
-    )
-    parser.add_argument(
-        "--output_file_path",
-        type=str,
-        help="The path of the output files",
-        required=True,
-    )
-    return parser
-
-
-def add_cluster_args(parser):
-    """
-    This function adds the command line arguments related to Dask cluster setup.
-
-    Args:
-        parser: An argparse ArgumentParser object.
-    Returns:
-        An argparse ArgumentParser with 8 additional arguments.
-
-    """
-    parser.add_argument(
-        "--scheduler-address",
-        type=str,
-        default=None,
-        help="""Address to the scheduler of a created Dask cluster.
-                If not provided, a single node LocalCUDACluster will be started.""",
-    )
-    parser.add_argument(
-        "--scheduler-file",
-        type=str,
-        default=None,
-        help="""Path to the scheduler file of a created Dask cluster.
-                If not provided, a single node LocalCUDACluster will be started.""",
-    )
-    parser.add_argument(
-        "--protocol",
-        type=str,
-        default="ucx",
-        help="""Protocol to use for Dask cluster.
-                Note: This only applies to the LocalCUDACluster.
-                If providing a user created cluster, refer to
-                https://docs.rapids.ai/api/dask-cuda/stable/api.html#cmdoption-dask-cuda-protocol""",
-    )
-    parser.add_argument(
-        "--nvlink-only",
-        action="store_true",
-        help="""Start a local cluster with only NVLink enabled.
-                Only applicable when protocol=ucx and no scheduler file/address is specified.""",
-    )
-    parser.add_argument(
-        "--rmm_pool_size",
-        type=str,
-        help="The size of the RMM pool to be used by each worker.",
-        default="14GB",
-    )
-    parser.add_argument(
-        "--CUDA_VISIBLE_DEVICES",
-        type=str,
-        help="The GPUs to be used by the cluster.",
-        default=None,
-    )
-    parser.add_argument("--enable_spilling", action="store_true")
-    parser.add_argument("--set_torch_to_use_rmm", action="store_true")
-    return parser
-
-
-def add_model_args(parser):
-    """
-    This function adds the command line arguments related to the model.
-
-    Args:
-        parser: An argparse ArgumentParser object.
-    Returns:
-        An argparse ArgumentParser with 4 additional arguments.
-
-    """
-    # Add a mutually exclusive group for model_file_name and model_file_names
-    group = parser.add_mutually_exclusive_group(required=True)
-    group.add_argument(
-        "--model_file_name",
-        type=str,
-        help="The path to the model file",
-        required=False,
-    )
-    group.add_argument(
-        "--model_file_names",
-        type=str,
-        nargs="*",
-        help="A list of model file paths",
-        required=False,
-    )
-    parser.add_argument(
-        "--autocast",
-        action="store_true",
-        help="Whether to use autocast or not",
-    )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=128,
-        help="The batch size to be used for inference",
-    )
-    return parser
-
-
-def create_arg_parser():
-    """
-    This function creates the argument parser to add the command line arguments.
-
-    Returns:
-        An argparse ArgumentParser object.
-
-    """
-    import argparse
-
-    parser = argparse.ArgumentParser(description="Run multi-node multi-GPU inference")
-    parser = add_cluster_args(parser)
-    parser = add_input_output_args(parser)
-    parser = add_model_args(parser)
-    return parser
diff --git a/nemo_curator/distributed_data_classification/domain_classifier_inference.py b/nemo_curator/distributed_data_classification/domain_classifier_inference.py
deleted file mode 100644
index 7fdeb6cfa..000000000
--- a/nemo_curator/distributed_data_classification/domain_classifier_inference.py
+++ /dev/null
@@ -1,275 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import time
-import warnings
-
-os.environ["RAPIDS_NO_INITIALIZE"] = "1"
-import torch
-from packaging import version
-from transformers import __version__ as TRANSFORMERS_VERSION
-from transformers.models.deberta_v2 import DebertaV2TokenizerFast
-
-from nemo_curator.distributed_data_classification.arg_utils import create_arg_parser
-from nemo_curator.distributed_data_classification.pytorch_utils import (
-    CFG,
-    CustomModel,
-    TestDataset,
-    collate,
-)
-from nemo_curator.utils.distributed_utils import (
-    get_client,
-    load_object_on_worker,
-    process_all_batches,
-    read_data,
-    write_to_disk,
-)
-from nemo_curator.utils.file_utils import get_remaining_files
-
-warnings.filterwarnings("ignore")
-
-
-def inference_per_partition(
-    df,
-    max_chars,
-    batch_size,
-    num_workers,
-    model_file_name,
-    labels,
-    autocast,
-):
-    """
-    This function runs domain classification on a subset of the data.
-    It loads the CFG, a data iterator, and then calls the `process_all_batches` function,
-    which loads the domain classifier and runs inference.
-
-    Args:
-        df: A Dask DataFrame partition with a "text" column and a dummy "pred" column.
-        max_chars: The maximum number of characters allowed in the truncated text.
-        batch_size: How many samples per batch to load with PyTorch DataLoader.
-        num_workers: How many subprocesses to use for PyTorch DataLoader.
-        model_file_name: The path to the model file.
-        labels: The list of domain labels.
-        autocast: A boolean representing whether to perform inference with mixed precision.
-    Returns:
-        The input Dask DataFrame with the calculated "pred" column.
-
-    """
-    cfg = cfg_per_partition()
-
-    dataset_valid = TestDataset(cfg, df, max_chars)
-    loader_valid = torch.utils.data.DataLoader(
-        dataset_valid,
-        batch_size=batch_size,
-        shuffle=False,
-        num_workers=num_workers,
-    )
-    device = torch.device("cuda")
-    load_model_kwargs = {"cfg": cfg, "device": device, "model_path": model_file_name}
-    run_inference_kwargs = {"autocast": autocast}
-    st = time.time()
-    preds = process_all_batches(
-        loader_valid,
-        load_model,
-        load_model_kwargs,
-        run_inference,
-        run_inference_kwargs,
-    )
-    preds = preds.cpu().numpy()
-    df["pred"] = [labels[i] for i in preds]
-
-    et = time.time()
-    print(
-        f"Time taken for inference for num_batches: {len(loader_valid)} : {et-st} s",
-        flush=True,
-    )
-
-    return df
-
-
-def cfg_per_partition():
-    """
-    This function loads the CFG on the worker currently running the task.
-    See `load_object_on_worker` function.
-
-    Returns:
-        A CFG with a set `tokenizer` attribute.
-
-    """
-    return load_object_on_worker("cfg_with_tokenizer", load_cfg_with_tokenizer, {})
-
-
-def load_cfg_with_tokenizer():
-    """
-    This function loads the CFG needed for domain classification.
-
-    Returns:
-        A CFG with a set `tokenizer` attribute.
-
-    """
-    cfg = CFG()
-    tokenizer = DebertaV2TokenizerFast.from_pretrained(cfg.model)
-    cfg.tokenizer = tokenizer
-    return cfg
-
-
-def load_model(cfg, device, model_path):
-    """
-    This function loads the domain model and prepares it to be used for inference.
-    It is needed as an input to the `process_all_batches` function within the `inference_per_partition` function.
-
-    Args:
-        cfg: A CFG object.
-        device: A specified PyTorch device, such as torch.device("cuda") or torch.device("cpu").
-        model_path: The path to the model file.
-    Returns:
-        The loaded model.
-
-    """
-    model = CustomModel(cfg, out_dim=27, config_path=None, pretrained=True)
-    model = model.to(device)
-    sd = torch.load(os.path.join(model_path), map_location="cpu")
-    sd = {k[7:] if k.startswith("module.") else k: sd[k] for k in sd.keys()}
-    if version.parse(TRANSFORMERS_VERSION) >= version.parse("4.31.0"):
-        sd.pop("model.embeddings.position_ids", None)
-
-    model.load_state_dict(sd, strict=True)
-    model.eval()
-    return model
-
-
-def run_inference(batch, model, autocast=False):
-    """
-    This function runs the domain classifier on a batch of data.
-    It is needed as an input to the `process_all_batches` function within the `inference_per_partition` function.
-
-    Args:
-        batch: A subset of the data as we are iterating through PyTorch DataLoader.
-        model: The loaded domain classification model.
-        autocast: A boolean representing whether to perform inference with mixed precision.
-    Returns:
-        A tensor of predictions.
-
-    """
-    with torch.no_grad():
-        batch = collate(batch)
-        if autocast:
-            with torch.autocast(device_type="cuda"):
-                out = model(batch)[:, 0, :]
-        else:
-            out = model(batch)[:, 0, :]
-        pred_idx = torch.sigmoid(out).argmax(1)
-
-    return pred_idx
-
-
-def main():
-    labels = [
-        "Adult",
-        "Arts_and_Entertainment",
-        "Autos_and_Vehicles",
-        "Beauty_and_Fitness",
-        "Books_and_Literature",
-        "Business_and_Industrial",
-        "Computers_and_Electronics",
-        "Finance",
-        "Food_and_Drink",
-        "Games",
-        "Health",
-        "Hobbies_and_Leisure",
-        "Home_and_Garden",
-        "Internet_and_Telecom",
-        "Jobs_and_Education",
-        "Law_and_Government",
-        "News",
-        "Online_Communities",
-        "People_and_Society",
-        "Pets_and_Animals",
-        "Real_Estate",
-        "Reference",
-        "Science",
-        "Sensitive_Subjects",
-        "Shopping",
-        "Sports",
-        "Travel_and_Transportation",
-    ]
-
-    args = create_arg_parser().parse_args()
-    print(f"Arguments parsed = {args}", flush=True)
-    max_chars = 2000
-    batch_size = args.batch_size
-    num_workers = 0
-
-    client = get_client(args, cluster_type="gpu")
-    print("Starting domain classifier inference", flush=True)
-    global_st = time.time()
-    files_per_run = len(client.scheduler_info()["workers"]) * 2
-    input_files = get_remaining_files(
-        args.input_file_path, args.output_file_path, args.input_file_type
-    )
-    print(f"Total input files {len(input_files)}", flush=True)
-
-    if args.input_file_type == "pickle":
-        add_filename = False
-    else:
-        add_filename = True
-
-    for file_batch_id, i in enumerate(range(0, len(input_files), files_per_run)):
-        batch_st = time.time()
-        current_batch_files = input_files[i : i + files_per_run]
-        print(
-            f"File Batch ID {file_batch_id}: total input files {len(current_batch_files)}",
-            flush=True,
-        )
-        df = read_data(
-            input_files=current_batch_files,
-            file_type=args.input_file_type,
-            add_filename=add_filename,
-        )
-        print(f"Total input Dask DataFrame partitions {df.npartitions}", flush=True)
-        meta_df = df._meta.copy()
-        meta_df["pred"] = [0] * len(meta_df)
-        df = df.map_partitions(
-            inference_per_partition,
-            max_chars,
-            batch_size,
-            num_workers,
-            args.model_file_name,
-            labels,
-            args.autocast,
-            meta=meta_df,
-            enforce_metadata=False,
-        )
-        write_to_disk(
-            df=df,
-            output_file_dir=args.output_file_path,
-            write_to_filename=add_filename,
-        )
-        batch_et = time.time()
-        print(
-            f"File Batch ID {file_batch_id}: completed in {batch_et-batch_st} seconds",
-            flush=True,
-        )
-
-    global_et = time.time()
-    print(
-        f"Total time taken for domain classifier inference: {global_et-global_st} s",
-        flush=True,
-    )
-    client.close()
-
-
-def console_script():
-    main()
diff --git a/nemo_curator/distributed_data_classification/generate_statistics.py b/nemo_curator/distributed_data_classification/generate_statistics.py
deleted file mode 100644
index 17e76c92d..000000000
--- a/nemo_curator/distributed_data_classification/generate_statistics.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import argparse
-import time
-
-from nemo_curator.distributed_data_classification.arg_utils import (
-    add_cluster_args,
-    add_input_output_args,
-)
-from nemo_curator.utils.distributed_utils import get_client, read_data
-from nemo_curator.utils.file_utils import get_all_files_paths_under
-
-
-def value_counts(df, column_name):
-    """
-    This function groups a DataFrame by the specified column and counts the occurrences of each group.
-    It is essentially the same as pandas.Series.value_counts, except it returns a DataFrame.
-
-    Args:
-        df: A DataFrame.
-        column_name: The column by which to group the DataFrame.
-    Returns:
-        A DataFrame with two columns: column_name and a second column containing the counts per group.
-
-    """
-    return df.groupby(column_name).size().reset_index()
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description="Generate label statistics and write them to disk"
-    )
-
-    parser = add_cluster_args(parser)
-    parser = add_input_output_args(parser)
-    parser.add_argument(
-        "--label",
-        type=str,
-        help="The label column on which to generate statistics",
-        required=True,
-    )
-    args = parser.parse_args()
-    print(f"Arguments parsed = {args}", flush=True)
-    client = get_client(args, cluster_type="gpu")
-
-    print("Starting statistics workflow", flush=True)
-    st = time.time()
-
-    df = read_data(
-        input_files=get_all_files_paths_under(
-            args.input_file_path, recurse_subdirecties=False
-        ),
-        file_type=args.input_file_type,
-        add_filename=True,
-    )
-    input_files = get_all_files_paths_under(
-        args.input_file_path, recurse_subdirecties=False
-    )
-
-    result = value_counts(df, column_name=args.label)
-    result = result.rename(columns={0: "count"})
-    result.to_json(args.output_file_path)
-
-    et = time.time()
-    print(f"Statistics workflow completed in {et-st}", flush=True)
-    client.close()
-
-
-def console_script():
-    main()
diff --git a/nemo_curator/distributed_data_classification/pytorch_utils.py b/nemo_curator/distributed_data_classification/pytorch_utils.py
deleted file mode 100644
index 58d9b5bf5..000000000
--- a/nemo_curator/distributed_data_classification/pytorch_utils.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-os.environ["RAPIDS_NO_INITIALIZE"] = "1"
-import torch
-import torch.nn as nn
-from torch.utils.data import Dataset
-from transformers import AutoConfig, AutoModel
-
-
-class CFG:
-    model = "microsoft/deberta-v3-base"
-    fc_dropout = 0.2
-
-    def __init__(self, max_len=512):
-        self.max_len = max_len
-
-
-def collate(inputs):
-    inputs = {k: v.to("cuda") for k, v in inputs.items()}
-    mask_len = int(inputs["attention_mask"].sum(axis=1).max())
-    for k, v in inputs.items():
-        # CPMP: no need to truncate labels
-        if k != "labels":
-            inputs[k] = inputs[k][:, :mask_len]
-    return inputs
-
-
-class CustomModel(nn.Module):
-    def __init__(self, cfg, out_dim, config_path=None, pretrained=False):
-        super().__init__()
-        self.cfg = cfg
-        if config_path is None:
-            self.config = AutoConfig.from_pretrained(
-                cfg.model, output_hidden_states=True
-            )
-        else:
-            self.config = torch.load(config_path)
-        if pretrained:
-            self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
-        else:
-            self.model = AutoModel(self.config)
-        self.fc_dropout = nn.Dropout(cfg.fc_dropout)
-        self.fc = nn.Linear(self.config.hidden_size, out_dim)
-        self._init_weights(self.fc)
-
-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-
-    def feature(self, input_ids, attention_mask):
-        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
-        last_hidden_states = outputs[0]
-        return last_hidden_states
-
-    def forward(self, batch):
-        feature = self.feature(batch["input_ids"], batch["attention_mask"])
-        output = self.fc(self.fc_dropout(feature))
-        return output
-
-
-class TestDataset(Dataset):
-    def __init__(self, cfg, df, max_chars):
-        self.cfg = cfg
-        text = df["text"].str.slice(0, max_chars).to_arrow().to_pylist()
-        with torch.no_grad():
-            self.tokens = cfg.tokenizer.batch_encode_plus(
-                text,
-                return_tensors="pt",
-                add_special_tokens=True,
-                max_length=cfg.max_len,
-                pad_to_max_length=True,
-                truncation=True,
-                return_token_type_ids=False,
-            )
-        self.max_chars = max_chars
-        self.dataset_len = len(text)
-
-    def __len__(self):
-        return self.dataset_len
-
-    def __getitem__(self, item):
-        return {k: v[item] for k, v in self.tokens.items()}
diff --git a/nemo_curator/distributed_data_classification/quality_classifier_inference.py b/nemo_curator/distributed_data_classification/quality_classifier_inference.py
deleted file mode 100644
index 90d841346..000000000
--- a/nemo_curator/distributed_data_classification/quality_classifier_inference.py
+++ /dev/null
@@ -1,314 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import time
-import warnings
-
-os.environ["RAPIDS_NO_INITIALIZE"] = "1"
-import torch
-from transformers.models.deberta_v2 import DebertaV2TokenizerFast
-
-from nemo_curator.distributed_data_classification.arg_utils import create_arg_parser
-from nemo_curator.distributed_data_classification.pytorch_utils import (
-    CFG,
-    CustomModel,
-    TestDataset,
-    collate,
-)
-from nemo_curator.utils.distributed_utils import (
-    get_client,
-    load_object_on_worker,
-    process_all_batches,
-    read_data,
-    write_to_disk,
-)
-from nemo_curator.utils.file_utils import get_remaining_files
-
-warnings.filterwarnings("ignore")
-
-
-def inference_per_partition(
-    df,
-    max_chars,
-    batch_size,
-    num_workers,
-    model_file_name,
-    labels,
-    autocast,
-    include_model_name=False,
-):
-    """
-    This function runs quality classification on a subset of the data.
-    It loads the CFG, a data iterator, and then calls the `process_all_batches` function,
-    which loads the quality classifier and runs inference.
-    It also contains some additional logic to handle binary versus multiclass classification.
-
-    Args:
-        df: A Dask DataFrame partition with a "text" column and a dummy "quality_pred" column.
-        max_chars: The maximum number of characters allowed in the truncated text.
-        batch_size: How many samples per batch to load with PyTorch DataLoader.
-        num_workers: How many subprocesses to use for PyTorch DataLoader.
-        model_file_name: The path to the model file.
-        labels: The list of domain labels.
-        autocast: A boolean representing whether to perform inference with mixed precision.
-        include_model_name: A boolean representing whether to include the model name in the "quality_pred" column name.
-    Returns:
-        The input Dask DataFrame with the calculated "quality_pred" column.
-
-    """
-    cfg = cfg_per_partition()
-
-    dataset_valid = TestDataset(cfg, df, max_chars)
-    loader_valid = torch.utils.data.DataLoader(
-        dataset_valid,
-        batch_size=batch_size,
-        shuffle=False,
-        num_workers=num_workers,
-    )
-    device = torch.device("cuda")
-    if len(labels) == 1:
-        raise ValueError("Labels must be more than 1")
-
-    # binary case
-    if len(labels) == 2:
-        out_dim = 1
-        binary_classification = True
-    else:
-        out_dim = len(labels)
-        binary_classification = False
-
-    load_model_kwargs = {
-        "cfg": cfg,
-        "device": device,
-        "model_path": model_file_name,
-        "out_dim": out_dim,
-    }
-    run_inference_kwargs = {
-        "autocast": autocast,
-        "binary_classification": binary_classification,
-    }
-    st = time.time()
-    probs = process_all_batches(
-        loader_valid,
-        load_model,
-        load_model_kwargs,
-        run_inference,
-        run_inference_kwargs,
-    )
-    if binary_classification:
-        preds = (probs > 0.5).to(torch.int64).squeeze()
-    else:
-        preds = torch.argmax(probs, dim=1)
-    # TODO: Do this without a CPU roundtrip in the future
-    if include_model_name:
-        filename = os.path.basename(model_file_name)
-        df[f"quality_pred_{filename}"] = [
-            labels[i] for i in preds.to("cpu").numpy().tolist()
-        ]
-        df[f"quality_prob_{filename}"] = probs.to("cpu").numpy().tolist()
-    else:
-        df["quality_pred"] = [labels[i] for i in preds.to("cpu").numpy().tolist()]
-        df["quality_prob"] = probs.to("cpu").numpy().tolist()
-    et = time.time()
-    print(
-        f"Time taken for inference for num_batches: {len(loader_valid)} : {et-st} s",
-        flush=True,
-    )
-
-    return df
-
-
-def cfg_per_partition():
-    """
-    This function loads the CFG on the worker currently running the task.
-    See `load_object_on_worker` function.
-
-    Returns:
-        A CFG with a set `tokenizer` attribute.
-
-    """
-    return load_object_on_worker("cfg_with_tokenizer", load_cfg_with_tokenizer, {})
-
-
-def load_cfg_with_tokenizer():
-    """
-    This function loads the CFG needed for quality classification.
-
-    Returns:
-        A CFG with a set `tokenizer` attribute.
-
-    """
-    cfg = CFG(max_len=1024)
-    tokenizer = DebertaV2TokenizerFast.from_pretrained(cfg.model)
-    cfg.tokenizer = tokenizer
-    return cfg
-
-
-def load_model(cfg, device, model_path, out_dim):
-    """
-    This function loads the quality model and prepares it to be used for inference.
-    It is needed as an input to the `process_all_batches` function within the `inference_per_partition` function.
-
-    Args:
-        cfg: A CFG object.
-        device: A specified PyTorch device, such as torch.device("cuda") or torch.device("cpu").
-        model_path: The path to the model file.
-        out_dim: An integer which corresponds to the number of labels. Use 1 for binary classification.
-    Returns:
-        The loaded model.
-
-    """
-    model = CustomModel(cfg, out_dim=out_dim, config_path=None, pretrained=True)
-    model = model.to(device)
-    sd = torch.load(model_path, map_location="cpu")
-    if "model_state_dict" in sd:
-        sd = sd["model_state_dict"]
-    sd = {k[7:] if k.startswith("module.") else k: sd[k] for k in sd.keys()}
-    model.load_state_dict(sd, strict=True)
-    model.eval()
-    return model
-
-
-def run_inference(batch, model, autocast=False, binary_classification=False):
-    """
-    This function runs the quality classifier on a batch of data.
-    It is needed as an input to the `process_all_batches` function within the `inference_per_partition` function.
-
-    Args:
-        batch: A subset of the data as we are iterating through PyTorch DataLoader.
-        model: The loaded quality classification model.
-        autocast: A boolean representing whether to perform inference with mixed precision.
-        binary_classification: A boolean representing whether it is a binary classification model.
-    Returns:
-        A tensor of predictions.
-
-    """
-    with torch.no_grad():
-        batch = collate(batch)
-        if autocast:
-            with torch.autocast(device_type="cuda"):
-                out = model(batch)[:, 0, :]
-        else:
-            out = model(batch)[:, 0, :]
-        if binary_classification:
-            probs = torch.sigmoid(out)
-        else:
-            probs = torch.softmax(out, dim=1)
-    return probs
-
-
-def add_quality_model_specific_args(parser):
-    """
-    This function adds a command line argument for the number of labels.
-
-    Args:
-        parser: An argparse ArgumentParser object.
-    Returns:
-        An argparse ArgumentParser with 1 additional argument.
-
-    """
-    parser.add_argument("--num-labels", type=int, default=3)
-    return parser
-
-
-def get_labels(num_labels):
-    """
-    This function returns a list of quality labels, depending on how many labels the user expects.
-
-    Args:
-        num_labels: An integer representing the number of possible classification labels.
-    Returns:
-        A list of label names.
-
-    """
-    if num_labels == 3:
-        labels = ["High", "Medium", "Low"]
-    elif num_labels == 2:
-        labels = ["Medium_High", "Low"]
-    return labels
-
-
-def main():
-    parser = create_arg_parser()
-    parser = add_quality_model_specific_args(parser)
-    args = parser.parse_args()
-    labels = get_labels(args.num_labels)
-    print(f"Arguments parsed = {args}", flush=True)
-    max_chars = 6000
-    batch_size = args.batch_size
-    num_workers = 0
-
-    client = get_client(args, cluster_type="gpu")
-    print("Starting quality classifier inference", flush=True)
-    global_st = time.time()
-    files_per_run = len(client.scheduler_info()["workers"]) * 2
-    input_files = get_remaining_files(
-        args.input_file_path, args.output_file_path, args.input_file_type
-    )
-    print(f"Total input files {len(input_files)}", flush=True)
-
-    if args.input_file_type == "pickle":
-        add_filename = False
-    else:
-        add_filename = True
-
-    for file_batch_id, i in enumerate(range(0, len(input_files), files_per_run)):
-        batch_st = time.time()
-        current_batch_files = input_files[i : i + files_per_run]
-        print(
-            f"File Batch ID {file_batch_id}: total input files {len(current_batch_files)}",
-            flush=True,
-        )
-        df = read_data(
-            input_files=current_batch_files,
-            file_type=args.input_file_type,
-            add_filename=add_filename,
-        )
-        print(f"Total input Dask DataFrame partitions {df.npartitions}", flush=True)
-        meta_df = df._meta.copy()
-        meta_df["quality_pred"] = ["low"] * len(meta_df)
-        meta_df["quality_prob"] = [[0, 0, 1]] * len(meta_df)
-        df = df.map_partitions(
-            inference_per_partition,
-            max_chars,
-            batch_size,
-            num_workers,
-            args.model_file_name,
-            labels,
-            args.autocast,
-            meta=meta_df,
-            enforce_metadata=False,
-        )
-        write_to_disk(
-            df=df,
-            output_file_dir=args.output_file_path,
-            write_to_filename=add_filename,
-        )
-        batch_et = time.time()
-        print(
-            f"File Batch ID {file_batch_id}: completed in {batch_et-batch_st} seconds",
-            flush=True,
-        )
-
-    global_et = time.time()
-    print(
-        f"Total time taken for quality classifier inference: {global_et-global_st} s",
-        flush=True,
-    )
-    client.close()
-
-
-def console_script():
-    main()
diff --git a/nemo_curator/distributed_data_classification/quality_classifier_multiple_models_inference.py b/nemo_curator/distributed_data_classification/quality_classifier_multiple_models_inference.py
deleted file mode 100644
index 403741c2e..000000000
--- a/nemo_curator/distributed_data_classification/quality_classifier_multiple_models_inference.py
+++ /dev/null
@@ -1,155 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import time
-
-from dask.distributed import wait
-
-from nemo_curator.distributed_data_classification.arg_utils import create_arg_parser
-from nemo_curator.distributed_data_classification.quality_classifier_inference import (
-    add_quality_model_specific_args,
-    get_labels,
-    inference_per_partition,
-)
-from nemo_curator.utils.distributed_utils import (
-    get_client,
-    offload_object_on_worker,
-    read_data,
-    write_to_disk,
-)
-from nemo_curator.utils.file_utils import get_remaining_files
-
-
-def delete_model_and_tokenizer_from_workers(client):
-    """
-    Offloads cfg_with_tokenizer and model from all Dask client workers.
-
-    Args:
-        client: A Dask client object.
-
-    """
-    task_ls = []
-    # TODO: client.run does not work anymore
-    # See: https://dask.discourse.group/t/cannot-run-client-run-function-when-function-contains-get-worker-in-distributed-2023-3-2-1/1772
-    # find a better alternate
-    for worker in client.scheduler_info()["workers"]:
-        task_ls.append(
-            client.submit(
-                offload_object_on_worker,
-                "cfg_with_tokenizer",
-                workers=[worker],
-                allow_other_workers=False,
-                pure=False,
-            )
-        )
-        task_ls.append(
-            client.submit(
-                offload_object_on_worker,
-                "model",
-                workers=[worker],
-                allow_other_workers=False,
-                pure=False,
-            )
-        )
-    wait(task_ls)
-    for t in task_ls:
-        assert t.result() == True
-    del task_ls
-
-
-def main():
-    parser = create_arg_parser()
-    parser = add_quality_model_specific_args(parser)
-    args = parser.parse_args()
-    labels = get_labels(args.num_labels)
-    print(f"Arguments parsed = {args}", flush=True)
-
-    max_chars = 6000
-    batch_size = args.batch_size
-    num_workers = 0
-    client = get_client(args, cluster_type="gpu")
-    client.upload_file("quality_classifier_inference.py")
-
-    print("Starting quality classifier inference", flush=True)
-    global_st = time.time()
-    files_per_run = len(client.scheduler_info()["workers"]) * 2
-    input_files = get_remaining_files(
-        args.input_file_path, args.output_file_path, args.input_file_type
-    )
-    print(f"Total input files {len(input_files)}", flush=True)
-
-    if args.input_file_type == "pickle":
-        add_filename = False
-    else:
-        add_filename = True
-
-    for file_batch_id, i in enumerate(range(0, len(input_files), files_per_run)):
-        batch_st = time.time()
-        current_batch_files = input_files[i : i + files_per_run]
-        print(
-            f"File Batch ID {file_batch_id}: total input files {len(current_batch_files)}",
-            flush=True,
-        )
-        df = read_data(
-            input_files=current_batch_files,
-            file_type=args.input_file_type,
-            add_filename=add_filename,
-        )
-        print(f"Total input Dask DataFrame partitions {df.npartitions}", flush=True)
-
-        for model_file_path in args.model_file_names:
-            meta_df = df._meta.copy()
-            model_file_name = os.path.basename(model_file_path)
-            print(f"model_file_name={model_file_name}", flush=True)
-            print("--" * 30, flush=True)
-            meta_df[f"quality_pred_{model_file_name}"] = ["low"] * len(meta_df)
-            meta_df[f"quality_prob_{model_file_name}"] = [[0, 0, 1]] * len(meta_df)
-            df = df.map_partitions(
-                inference_per_partition,
-                max_chars,
-                batch_size,
-                num_workers,
-                model_file_path,
-                labels,
-                args.autocast,
-                include_model_name=True,
-                meta=meta_df,
-                enforce_metadata=False,
-            )
-            df = df.persist()
-            wait(df)
-            delete_model_and_tokenizer_from_workers(client)
-
-        write_to_disk(
-            df=df,
-            output_file_dir=args.output_file_path,
-            write_to_filename=add_filename,
-        )
-        batch_et = time.time()
-        print(
-            f"File Batch ID {file_batch_id}: completed in {batch_et-batch_st} seconds",
-            flush=True,
-        )
-
-    global_et = time.time()
-    print(
-        f"Total time taken for multiple quality classifier inference models: {global_et-global_st} s",
-        flush=True,
-    )
-    client.close()
-
-
-def console_script():
-    main()
diff --git a/nemo_curator/modules/distributed_data_classifier.py b/nemo_curator/modules/distributed_data_classifier.py
index e45d0ba56..6bb975eec 100644
--- a/nemo_curator/modules/distributed_data_classifier.py
+++ b/nemo_curator/modules/distributed_data_classifier.py
@@ -13,24 +13,88 @@
 # limitations under the License.
 
 import os
+
+os.environ["RAPIDS_NO_INITIALIZE"] = "1"
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 
 import torch
+import torch.nn as nn
+from crossfit import op
+from crossfit.backend.torch.hf.model import HFModel
 from packaging import version
+from transformers import AutoConfig, AutoModel
 from transformers import __version__ as TRANSFORMERS_VERSION
 from transformers.models.deberta_v2 import DebertaV2TokenizerFast
 
 from nemo_curator.datasets import DocumentDataset
-from nemo_curator.distributed_data_classification.pytorch_utils import (
-    CFG,
-    CustomModel,
-    TestDataset,
-    collate,
-)
-from nemo_curator.utils.distributed_utils import (
-    load_object_on_worker,
-    process_all_batches,
-)
+
+
+@dataclass
+class DomainModelConfig:
+    model: str = "microsoft/deberta-v3-base"
+    fc_dropout: float = 0.2
+    max_len: int = 512
+
+
+@dataclass
+class QualityModelConfig:
+    model: str = "microsoft/deberta-v3-base"
+    fc_dropout: float = 0.2
+    max_len: int = 512
+
+
+class CustomModel(nn.Module):
+    def __init__(
+        self, config, out_dim, config_path=None, pretrained=False, autocast=False
+    ):
+        super().__init__()
+        self.config = config
+        if config_path is None:
+            self.config = AutoConfig.from_pretrained(
+                config.model, output_hidden_states=True
+            )
+        else:
+            self.config = torch.load(config_path)
+        if pretrained:
+            self.model = AutoModel.from_pretrained(config.model, config=self.config)
+        else:
+            self.model = AutoModel(self.config)
+        self.fc_dropout = nn.Dropout(config.fc_dropout)
+        self.fc = nn.Linear(self.config.hidden_size, out_dim)
+        self._init_weights(self.fc)
+        self.autocast = autocast
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def feature(self, input_ids, attention_mask):
+        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
+        last_hidden_states = outputs[0]
+        return last_hidden_states
+
+    def _forward(self, batch):
+        feature = self.feature(batch["input_ids"], batch["attention_mask"])
+        output = self.fc(self.fc_dropout(feature))
+        output = output.to(torch.float32)
+        return torch.softmax(output[:, 0, :], dim=1)
+
+    def forward(self, batch):
+        if self.autocast:
+            with torch.autocast(device_type="cuda"):
+                return self._forward(batch)
+        else:
+            return self._forward(batch)
 
 
 class DistributedDataClassifier(ABC):
@@ -38,31 +102,28 @@ class DistributedDataClassifier(ABC):
 
     def __init__(
         self,
-        model_file_name,
+        model,
         labels,
         filter_by,
         batch_size,
         out_dim,
         pred_column,
         max_chars,
-        num_workers,
         device_type,
         autocast,
     ):
-        self.model_file_name = model_file_name
+        self.model = model
         self.labels = labels
         self.filter_by = filter_by
         self.batch_size = batch_size
         self.out_dim = out_dim
         self.pred_column = pred_column
         self.max_chars = max_chars
-        self.num_workers = num_workers
         self.device_type = device_type
         self.autocast = autocast
 
     def __call__(self, dataset: DocumentDataset):
         result_doc_dataset = self._run_classifier(dataset)
-
         if self.filter_by is not None:
             return self._filter_documents(result_doc_dataset)
 
@@ -72,13 +133,6 @@ def __call__(self, dataset: DocumentDataset):
     def _run_classifier(self):
         pass
 
-    def _cfg_per_partition(self):
-        return load_object_on_worker(
-            "cfg_with_tokenizer",
-            self._load_cfg_with_tokenizer,
-            {},
-        )
-
     def _filter_documents(
         self,
         dataset: DocumentDataset,
@@ -96,116 +150,180 @@ def _filter_documents(
         raise TypeError("filter_by must be a string or list type")
 
 
+def _run_classifier_helper(
+    df: "dask_cudf.DataFrame",
+    model: "HFModel",
+    labels: list[str],
+    max_chars: int,
+    batch_size: int,
+    label_col: str,
+    prob_col: str = None,
+) -> "dask_cudf.DataFrame":
+
+    keep_prob = prob_col is not None
+    prob_internal_col = "_prob"
+    # TODO: Make crossfit handle this cleanly
+    pred_internal_col = "labels"
+    df["sliced_text"] = df["text"].str.slice(0, max_chars)
+    columns_to_keep_list = df.columns.to_list()
+    columns_to_keep_list.remove("sliced_text")
+
+    classifier_pipe = op.Sequential(
+        op.Tokenizer(model, cols=["sliced_text"], tokenizer_type="sentencepiece"),
+        op.Predictor(
+            model,
+            sorted_data_loader=True,
+            batch_size=batch_size,
+            pred_output_col=prob_internal_col,
+        ),
+        repartition=df.npartitions,
+        keep_cols=columns_to_keep_list,
+    )
+    df = classifier_pipe(df)
+    # TODO: Make crossfit handle this cleanly
+    # to prevent the labeler from dropping the prob_internal_col
+    # and combine it into a single step
+    labeling_pipe = op.Sequential(
+        op.Labeler(labels, cols=[prob_internal_col]),
+        keep_cols=columns_to_keep_list + [prob_internal_col],
+    )
+    df = labeling_pipe(df)
+    if keep_prob:
+        df = df.rename(
+            columns={prob_internal_col: prob_col, pred_internal_col: label_col},
+        )
+    else:
+        df = df.rename(columns={pred_internal_col: label_col})
+        df = df.drop(columns=[prob_internal_col])
+    return df
+
+
+class DomainModel(HFModel):
+    def __init__(self, config, out_dim=None, model_path=None, autocast=False):
+        self.config = config
+        self.out_dim = out_dim
+        self.model_path = model_path
+        self.autocast = autocast
+        super().__init__(self.config.model)
+
+    def load_model(self, device="cuda"):
+        model = CustomModel(
+            self.config,
+            out_dim=self.out_dim,
+            config_path=None,
+            pretrained=True,
+            autocast=self.autocast,
+        )
+        model = model.to(device)
+        if os.path.exists(self.model_path):
+            sd = torch.load(os.path.join(self.model_path), map_location="cpu")
+            sd = {k[7:] if k.startswith("module.") else k: sd[k] for k in sd.keys()}
+            if version.parse(TRANSFORMERS_VERSION) >= version.parse("4.31.0"):
+                sd.pop("model.embeddings.position_ids", None)
+            model.load_state_dict(sd, strict=True)
+        else:
+            raise ValueError(f"Model path {self.model_path} does not exist")
+        return model.eval()
+
+    def load_tokenizer(self):
+        return DebertaV2TokenizerFast.from_pretrained(self.config.model)
+
+    def load_config(self):
+        return AutoConfig.from_pretrained(self.path_or_name)
+
+
+class QualityModel(HFModel):
+    def __init__(self, config, out_dim=None, model_path=None, autocast=False):
+        self.config = config
+        self.out_dim = out_dim
+        self.model_path = model_path
+        self.autocast = autocast
+        super().__init__(self.config.model)
+
+    def load_model(self, device="cuda"):
+        model = CustomModel(
+            self.config,
+            out_dim=self.out_dim,
+            config_path=None,
+            pretrained=True,
+            autocast=self.autocast,
+        )
+        model = model.to(device)
+        if os.path.exists(self.model_path):
+            sd = torch.load(self.model_path, map_location="cpu")
+            if "model_state_dict" in sd:
+                sd = sd["model_state_dict"]
+            sd = {k[7:] if k.startswith("module.") else k: sd[k] for k in sd.keys()}
+            model.load_state_dict(sd, strict=True)
+        else:
+            raise ValueError(f"Model path {self.model_path} does not exist")
+        model.eval()
+        return model
+
+    def load_tokenizer(self):
+        return DebertaV2TokenizerFast.from_pretrained(self.config.model)
+
+    def load_config(self):
+        return AutoConfig.from_pretrained(self.path_or_name)
+
+
 class DomainClassifier(DistributedDataClassifier):
     def __init__(
         self,
-        model_file_name,
+        model_path,
         labels,
         filter_by=None,
         batch_size=256,
         out_dim=None,
-        pred_column="pred",
+        pred_column="domain_pred",
+        prob_column=None,
         max_chars=2000,
-        num_workers=0,
         device_type="cuda",
         autocast=True,
     ):
         if out_dim is None:
             out_dim = len(labels)
 
+        self.prob_column = prob_column
+
+        model = DomainModel(
+            config=DomainModelConfig,
+            out_dim=out_dim,
+            model_path=model_path,
+            autocast=autocast,
+        )
+
         super().__init__(
-            model_file_name=model_file_name,
+            model=model,
             labels=labels,
             filter_by=filter_by,
             batch_size=batch_size,
             out_dim=out_dim,
             pred_column=pred_column,
             max_chars=max_chars,
-            num_workers=num_workers,
             device_type=device_type,
             autocast=autocast,
         )
 
     def _run_classifier(self, dataset: DocumentDataset):
         print("Starting domain classifier inference", flush=True)
-
         df = dataset.df
-
-        meta_df = df._meta.copy()
-        meta_df[self.pred_column] = [0] * len(meta_df)
-
-        df = df.map_partitions(
-            self._inference_per_partition,
-            meta=meta_df,
-            enforce_metadata=False,
-        )
-
-        return DocumentDataset(df)
-
-    def _inference_per_partition(self, df):
-        cfg = self._cfg_per_partition()
-
-        dataset_valid = TestDataset(cfg, df, self.max_chars)
-        loader_valid = torch.utils.data.DataLoader(
-            dataset_valid,
+        df = _run_classifier_helper(
+            df=df,
+            model=self.model,
+            labels=self.labels,
+            max_chars=self.max_chars,
             batch_size=self.batch_size,
-            shuffle=False,
-            num_workers=self.num_workers,
-        )
-
-        device = torch.device(self.device_type)
-        load_model_kwargs = {"cfg": cfg, "device": device}
-
-        preds = process_all_batches(
-            loader_valid,
-            self._load_model,
-            load_model_kwargs,
-            self._run_inference,
-            {},
+            label_col=self.pred_column,
+            prob_col=self.prob_column,
         )
-        preds = preds.cpu().numpy()
-        df[self.pred_column] = [self.labels[i] for i in preds]
-
-        return df
-
-    def _load_cfg_with_tokenizer(self):
-        cfg = CFG()
-        tokenizer = DebertaV2TokenizerFast.from_pretrained(cfg.model)
-        cfg.tokenizer = tokenizer
-        return cfg
-
-    def _load_model(self, cfg, device):
-        model = CustomModel(
-            cfg, out_dim=self.out_dim, config_path=None, pretrained=True
-        )
-        model = model.to(device)
-        sd = torch.load(os.path.join(self.model_file_name), map_location="cpu")
-        sd = {k[7:] if k.startswith("module.") else k: sd[k] for k in sd.keys()}
-        if version.parse(TRANSFORMERS_VERSION) >= version.parse("4.31.0"):
-            sd.pop("model.embeddings.position_ids", None)
-
-        model.load_state_dict(sd, strict=True)
-        model.eval()
-        return model
-
-    def _run_inference(self, batch, model):
-        with torch.no_grad():
-            batch = collate(batch)
-            if self.autocast:
-                with torch.autocast(device_type=self.device_type):
-                    out = model(batch)[:, 0, :]
-            else:
-                out = model(batch)[:, 0, :]
-            pred_idx = torch.sigmoid(out).argmax(1)
-
-        return pred_idx
+        return DocumentDataset(df)
 
 
-# TODO: Implement MultipleModelQualityClassifier class
 class QualityClassifier(DistributedDataClassifier):
     def __init__(
         self,
-        model_file_name,
+        model_path,
         labels,
         filter_by=None,
         batch_size=256,
@@ -213,121 +331,46 @@ def __init__(
         pred_column="quality_pred",
         prob_column="quality_prob",
         max_chars=6000,
-        num_workers=0,
         device_type="cuda",
         autocast=True,
-        max_len=1024,
     ):
-        # Binary case
         if len(labels) == 2:
-            out_dim = 1
-            self.binary_classification = True
+            out_dim = 1  # Binary classification
         else:
             if out_dim is None:
-                out_dim = len(labels)
-            self.binary_classification = False
+                out_dim = len(labels)  # Multiclass classification
 
         self.prob_column = prob_column
-        self.max_len = max_len
+
+        model = QualityModel(
+            config=QualityModelConfig,
+            out_dim=out_dim,
+            model_path=model_path,
+            autocast=autocast,
+        )
 
         super().__init__(
-            model_file_name=model_file_name,
+            model=model,
             labels=labels,
             filter_by=filter_by,
             batch_size=batch_size,
             out_dim=out_dim,
             pred_column=pred_column,
             max_chars=max_chars,
-            num_workers=num_workers,
             device_type=device_type,
             autocast=autocast,
         )
 
     def _run_classifier(self, dataset: DocumentDataset):
-        print("Starting quality classifier inference", flush=True)
-
+        print("Starting Quality classifier inference", flush=True)
         df = dataset.df
-
-        meta_df = df._meta.copy()
-        meta_df[self.pred_column] = ["low"] * len(meta_df)
-        meta_df[self.prob_column] = [[0, 0, 1]] * len(meta_df)
-
-        df = df.map_partitions(
-            self._inference_per_partition,
-            meta=meta_df,
-            enforce_metadata=False,
-        )
-
-        return DocumentDataset(df)
-
-    def _inference_per_partition(self, df):
-        cfg = self._cfg_per_partition()
-
-        dataset_valid = TestDataset(cfg, df, self.max_chars)
-        loader_valid = torch.utils.data.DataLoader(
-            dataset_valid,
+        df = _run_classifier_helper(
+            df=df,
+            model=self.model,
+            labels=self.labels,
+            max_chars=self.max_chars,
             batch_size=self.batch_size,
-            shuffle=False,
-            num_workers=self.num_workers,
-        )
-        device = torch.device(self.device_type)
-        if len(self.labels) == 1:
-            raise ValueError("Labels must be more than 1")
-
-        load_model_kwargs = {
-            "cfg": cfg,
-            "device": device,
-        }
-
-        probs = process_all_batches(
-            loader_valid,
-            self._load_model,
-            load_model_kwargs,
-            self._run_inference,
-            {},
+            label_col=self.pred_column,
+            prob_col=self.prob_column,
         )
-
-        if self.binary_classification:
-            preds = (probs > 0.5).to(torch.int64).squeeze()
-        else:
-            preds = torch.argmax(probs, dim=1)
-
-        df[self.pred_column] = [
-            self.labels[i] for i in preds.to("cpu").numpy().tolist()
-        ]
-        df[self.prob_column] = probs.to("cpu").numpy().tolist()
-
-        return df
-
-    def _load_cfg_with_tokenizer(self):
-        cfg = CFG(max_len=self.max_len)
-        tokenizer = DebertaV2TokenizerFast.from_pretrained(cfg.model)
-        cfg.tokenizer = tokenizer
-        return cfg
-
-    def _load_model(self, cfg, device):
-        model = CustomModel(
-            cfg, out_dim=self.out_dim, config_path=None, pretrained=True
-        )
-        model = model.to(device)
-        sd = torch.load(self.model_file_name, map_location="cpu")
-        if "model_state_dict" in sd:
-            sd = sd["model_state_dict"]
-        sd = {k[7:] if k.startswith("module.") else k: sd[k] for k in sd.keys()}
-        model.load_state_dict(sd, strict=True)
-        model.eval()
-        return model
-
-    def _run_inference(self, batch, model):
-        with torch.no_grad():
-            batch = collate(batch)
-            if self.autocast:
-                with torch.autocast(device_type=self.device_type):
-                    out = model(batch)[:, 0, :]
-            else:
-                out = model(batch)[:, 0, :]
-            if self.binary_classification:
-                probs = torch.sigmoid(out)
-            else:
-                probs = torch.softmax(out, dim=1)
-        return probs
+        return DocumentDataset(df)
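
For reference, the refactored classifiers are constructed directly from a model_path and a label list; the Crossfit-backed model wrapper replaces the per-partition loader and inference plumbing deleted above. A minimal usage sketch of the binary-vs-multiclass out_dim resolution in QualityClassifier (the checkpoint path below is hypothetical, and constructing the classifier loads the model, so this needs a real checkpoint and a GPU to run):

    from nemo_curator import QualityClassifier

    # Two labels collapse to a single sigmoid output (out_dim=1), matching the
    # binary branch above; three or more labels default out_dim to len(labels).
    classifier = QualityClassifier(
        model_path="quality_model.pth",   # hypothetical checkpoint path
        labels=["Medium_High", "Low"],    # len == 2 -> out_dim resolves to 1
        batch_size=256,
    )
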
diff --git a/nemo_curator/scripts/domain_classifier_inference.py b/nemo_curator/scripts/domain_classifier_inference.py
new file mode 100644
index 000000000..9738c2d2b
--- /dev/null
+++ b/nemo_curator/scripts/domain_classifier_inference.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import warnings
+
+os.environ["RAPIDS_NO_INITIALIZE"] = "1"
+from nemo_curator import DomainClassifier
+from nemo_curator.datasets import DocumentDataset
+
+from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk
+from nemo_curator.utils.file_utils import get_remaining_files
+from nemo_curator.utils.script_utils import parse_distributed_classifier_args
+
+warnings.filterwarnings("ignore")
+
+
+def main():
+    labels = [
+        "Adult",
+        "Arts_and_Entertainment",
+        "Autos_and_Vehicles",
+        "Beauty_and_Fitness",
+        "Books_and_Literature",
+        "Business_and_Industrial",
+        "Computers_and_Electronics",
+        "Finance",
+        "Food_and_Drink",
+        "Games",
+        "Health",
+        "Hobbies_and_Leisure",
+        "Home_and_Garden",
+        "Internet_and_Telecom",
+        "Jobs_and_Education",
+        "Law_and_Government",
+        "News",
+        "Online_Communities",
+        "People_and_Society",
+        "Pets_and_Animals",
+        "Real_Estate",
+        "Science",
+        "Sensitive_Subjects",
+        "Shopping",
+        "Sports",
+        "Travel_and_Transportation",
+    ]
+
+    args = parse_distributed_classifier_args().parse_args()
+    print(f"Arguments parsed = {args}", flush=True)
+    max_chars = 2000
+
+    client = get_client(args, cluster_type="gpu")
+    print("Starting domain classifier inference", flush=True)
+    global_st = time.time()
+    files_per_run = len(client.scheduler_info()["workers"]) * 2
+
+    if not os.path.exists(args.output_data_dir):
+        os.makedirs(args.output_data_dir)
+
+    input_files = get_remaining_files(
+        args.input_data_dir, args.output_data_dir, args.input_file_type
+    )
+    print(f"Total input files {len(input_files)}", flush=True)
+
+    if args.input_file_type == "pickle":
+        add_filename = False
+    else:
+        add_filename = True
+
+    domain_classifier = DomainClassifier(
+        model_path=args.model_path,
+        labels=labels,
+        max_chars=max_chars,
+        batch_size=args.batch_size,
+        out_dim=len(labels),
+        autocast=args.autocast,
+    )
+
+    for file_batch_id, i in enumerate(range(0, len(input_files), files_per_run)):
+        batch_st = time.time()
+        current_batch_files = input_files[i : i + files_per_run]
+        print(
+            f"File Batch ID {file_batch_id}: total input files {len(current_batch_files)}",
+            flush=True,
+        )
+        df = read_data(
+            input_files=current_batch_files,
+            file_type=args.input_file_type,
+            add_filename=add_filename,
+        )
+        df = domain_classifier(DocumentDataset(df)).df
+        print(f"Total input Dask DataFrame partitions {df.npartitions}", flush=True)
+
+        write_to_disk(
+            df=df,
+            output_file_dir=args.output_data_dir,
+            write_to_filename=add_filename,
+            output_type=args.output_file_type,
+        )
+        batch_et = time.time()
+        print(
+            f"File Batch ID {file_batch_id}: completed in {batch_et-batch_st} seconds",
+            flush=True,
+        )
+
+    global_et = time.time()
+    print(
+        f"Total time taken for domain classifier inference: {global_et-global_st} s",
+        flush=True,
+    )
+    client.close()
+
+
+def console_script():
+    main()
+
+
+if __name__ == "__main__":
+    console_script()
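
The driver above sizes each run at twice the number of Dask workers so every GPU stays busy, and passes both the input and output directories to get_remaining_files so that reruns can skip files that already have outputs. A self-contained sketch of the same chunking pattern (file names and worker count are illustrative):

    # Process files in batches of (num_workers * 2), as in the loop above.
    input_files = [f"{i:02d}.jsonl" for i in range(10)]  # hypothetical inputs
    files_per_run = 2 * 2  # e.g., a cluster with 2 workers

    for batch_id, start in enumerate(range(0, len(input_files), files_per_run)):
        batch = input_files[start : start + files_per_run]
        print(f"Batch {batch_id}: {len(batch)} files -> {batch}")
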
diff --git a/nemo_curator/scripts/quality_classifier_inference.py b/nemo_curator/scripts/quality_classifier_inference.py
new file mode 100644
index 000000000..c7853394b
--- /dev/null
+++ b/nemo_curator/scripts/quality_classifier_inference.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import warnings
+
+os.environ["RAPIDS_NO_INITIALIZE"] = "1"
+from nemo_curator import QualityClassifier
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk
+from nemo_curator.utils.file_utils import get_remaining_files
+from nemo_curator.utils.script_utils import parse_distributed_classifier_args
+
+warnings.filterwarnings("ignore")
+
+
+def add_quality_model_specific_args(parser):
+    """
+    This function adds a command line argument for the number of labels.
+
+    Args:
+        parser: An argparse ArgumentParser object.
+    Returns:
+        An argparse ArgumentParser with 1 additional argument.
+
+    """
+    parser.add_argument("--num-labels", type=int, default=3)
+    return parser
+
+
+def get_labels(num_labels):
+    """
+    This function returns a list of quality labels, depending on how many labels the user expects.
+    Only 2 or 3 labels are supported.
+
+    Args:
+        num_labels: An integer representing the number of possible classification labels.
+    Returns:
+        A list of label names.
+
+    """
+    if num_labels == 3:
+        labels = ["High", "Medium", "Low"]
+    elif num_labels == 2:
+        labels = ["Medium_High", "Low"]
+    else:
+        raise ValueError(f"Unsupported number of labels: {num_labels} (expected 2 or 3)")
+    return labels
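+
+
+# Illustrative, doctest-style examples (hypothetical invocations, not executed):
+#     >>> get_labels(3)
+#     ['High', 'Medium', 'Low']
+#     >>> get_labels(2)
+#     ['Medium_High', 'Low']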
+
+
+def main():
+    parser = parse_distributed_classifier_args()
+    parser = add_quality_model_specific_args(parser)
+    args = parser.parse_args()
+    labels = get_labels(args.num_labels)
+    print(f"Arguments parsed = {args}", flush=True)
+    max_chars = 6000
+
+    client = get_client(args, cluster_type="gpu")
+    print("Starting quality classifier inference", flush=True)
+    global_st = time.time()
+    files_per_run = len(client.scheduler_info()["workers"]) * 2
+
+    if not os.path.exists(args.output_data_dir):
+        os.makedirs(args.output_data_dir)
+
+    input_files = get_remaining_files(
+        args.input_data_dir, args.output_data_dir, args.input_file_type
+    )
+    print(f"Total input files {len(input_files)}", flush=True)
+
+    if args.input_file_type == "pickle":
+        add_filename = False
+    else:
+        add_filename = True
+
+    classifier = QualityClassifier(
+        model_path=args.model_path,
+        max_chars=max_chars,
+        labels=labels,
+        batch_size=args.batch_size,
+        autocast=args.autocast,
+        out_dim=len(labels),
+    )
+
+    for file_batch_id, i in enumerate(range(0, len(input_files), files_per_run)):
+        batch_st = time.time()
+        current_batch_files = input_files[i : i + files_per_run]
+        print(
+            f"File Batch ID {file_batch_id}: total input files {len(current_batch_files)}",
+            flush=True,
+        )
+        df = read_data(
+            input_files=current_batch_files,
+            file_type=args.input_file_type,
+            add_filename=add_filename,
+        )
+        print(f"Total input Dask DataFrame partitions {df.npartitions}", flush=True)
+        df = classifier(DocumentDataset(df)).df
+        write_to_disk(
+            df=df,
+            output_file_dir=args.output_data_dir,
+            write_to_filename=add_filename,
+            output_type=args.output_file_type,
+        )
+        batch_et = time.time()
+        print(
+            f"File Batch ID {file_batch_id}: completed in {batch_et-batch_st} seconds",
+            flush=True,
+        )
+
+    global_et = time.time()
+    print(
+        f"Total time taken for quality classifier inference: {global_et-global_st} s",
+        flush=True,
+    )
+    client.close()
+
+
+def console_script():
+    main()
+
+
+if __name__ == "__main__":
+    console_script()
diff --git a/nemo_curator/distributed_data_classification/verify_results.py b/nemo_curator/scripts/verify_classification_results.py
similarity index 100%
rename from nemo_curator/distributed_data_classification/verify_results.py
rename to nemo_curator/scripts/verify_classification_results.py
diff --git a/nemo_curator/utils/script_utils.py b/nemo_curator/utils/script_utils.py
index e2811dd1e..582b85c21 100644
--- a/nemo_curator/utils/script_utils.py
+++ b/nemo_curator/utils/script_utils.py
@@ -165,6 +165,82 @@ def parse_gpu_dedup_args(
     return parser
 
 
+def parse_distributed_classifier_args(
+    description="Default distributed classifier argument parser",
+) -> argparse.ArgumentParser:
+    """
+    Adds default set of arguments that are common to multiple stages
+    of the pipeline
+    """
+
+    parser = argparse.ArgumentParser(
+        description,
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser = add_distributed_args(parser)
+    # Set low default RMM pool size for classifier
+    # to allow pytorch to grow its memory usage
+    # by default
+    parser.set_defaults(rmm_pool_size="512MB")
+    parser.add_argument(
+        "--input-data-dir",
+        type=str,
+        help="The path of the input files",
+        required=True,
+    )
+    parser.add_argument(
+        "--output-data-dir",
+        type=str,
+        help="The path of the output files",
+        required=True,
+    )
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        help="The path to the model file",
+        required=True,
+    )
+    parser.add_argument(
+        "--input-file-type",
+        type=str,
+        help="The type of the input files",
+        required=True,
+    )
+    parser.add_argument(
+        "--output-file-type",
+        type=str,
+        default="jsonl",
+        help="The type of the output files",
+        required=False,
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=128,
+        help="The batch size to be used for inference",
+    )
+    attach_bool_arg(
+        parser, "autocast", default=True, help_str="Whether to use autocast or not"
+    )
+    attach_bool_arg(
+        parser,
+        "enable-spilling",
+        default=True,
+        help_str="Whether to enable spilling or not",
+    )
+
+    # Setting to False makes it more stable for long running jobs
+    # possibly because of memory fragmentation
+    attach_bool_arg(
+        parser,
+        "set-torch-to-use-rmm",
+        default=False,
+        help_str="Whether to set torch to use RMM or not",
+    )
+
+    return parser
+
+
 def chunk_list(lst, nchnks):
     nitem = len(lst)
     splits = splitnum(nitem, nchnks)
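
Scripts compose this shared parser with their own flags, as the quality classifier script does with --num-labels. A minimal, runnable sketch of that composition pattern (the parser below is a stand-in for parse_distributed_classifier_args, and argv is supplied explicitly for illustration):

    import argparse

    # Stand-in for parse_distributed_classifier_args(); the real helper also
    # attaches the distributed-cluster and RMM arguments shown above.
    parser = argparse.ArgumentParser(
        description="distributed classifier example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default=128)

    # Script-specific extension, mirroring add_quality_model_specific_args.
    parser.add_argument("--num-labels", type=int, default=3)

    args = parser.parse_args(["--model-path", "quality_model.pth"])
    print(args.batch_size, args.num_labels)  # -> 128 3
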
diff --git a/setup.py b/setup.py
index 357e33e51..23f3c58e0 100644
--- a/setup.py
+++ b/setup.py
@@ -60,6 +60,7 @@
         "presidio-anonymizer==2.2.351",
         "usaddress==0.5.10",
         "nemo_toolkit[nlp]>=1.23.0",
+        "crossfit @ git+https://github.com/rapidsai/crossfit.git@1ee3de4",
         # justext installation breaks without lxml[html_clean]
         # due to this: https://github.com/miso-belica/jusText/issues/47
         "lxml[html_clean]",
@@ -68,6 +69,7 @@
         "cuda12x": [
             "cudf-cu12>=24.2",
             "dask-cudf-cu12>=24.2",
+            "cuml-cu12>=24.2",
             "cugraph-cu12>=24.2",
             "dask-cuda>=24.2",
             "spacy[cuda12x]>=3.6.0, <4.0.0",
@@ -97,11 +99,9 @@
             "gpu_connected_component=nemo_curator.scripts.fuzzy_deduplication.connected_components:console_script",
             "gpu_exact_dups=nemo_curator.scripts.find_exact_duplicates:console_script",
             "deidentify=nemo_curator.scripts.find_pii_and_deidentify:console_script",
-            "generate_statistics=nemo_curator.distributed_data_classification.generate_statistics:console_script",
-            "domain_classifier_inference=nemo_curator.distributed_data_classification.domain_classifier_inference:console_script",
-            "quality_classifier_multiple_models_inference=nemo_curator.distributed_data_classification.quality_classifier_multiple_models_inference:console_script",
-            "quality_classifier_inference=nemo_curator.distributed_data_classification.quality_classifier_inference:console_script",
-            "verify_results=nemo_curator.distributed_data_classification.verify_results:console_script",
+            "domain_classifier_inference=nemo_curator.scripts.domain_classifier_inference:console_script",
+            "quality_classifier_inference=nemo_curator.scripts.quality_classifier_inference:console_script",
+            "verify_classification_results=nemo_curator.scripts.verify_classification_results:console_script",
             "blend_datasets=nemo_curator.scripts.blend_datasets:console_script",
         ],
     },
diff --git a/tutorials/distributed_data_classification/distributed_data_classification.ipynb b/tutorials/distributed_data_classification/distributed_data_classification.ipynb
new file mode 100644
index 000000000..b0fec862c
--- /dev/null
+++ b/tutorials/distributed_data_classification/distributed_data_classification.ipynb
@@ -0,0 +1,376 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Distributed Data Classification with Quality and Domain Classifiers\n",
+    "\n",
+    "The notebook demonstrates the use of two classifiers for distributed data classification, including quality and domain classifiers. The quality classifier is used to classify the quality of the data, while the domain classifier is used to classify the domain of the data. These classifers help with annotation which helps data blending for foundation model training. \n",
+    "\n",
+    "The classifiers are accelerated using CrossFit,(https://github.com/rapidsai/crossfit), a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "env: PYTHONWARNINGS=ignore\n"
+     ]
+    }
+   ],
+   "source": [
+    "#### Silence Warnings (HuggingFace internal warnings)\n",
+    "\n",
+    "%env PYTHONWARNINGS=ignore\n",
+    "import warnings\n",
+    "warnings.filterwarnings(\"ignore\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dask_cuda import LocalCUDACluster\n",
+    "from dask.distributed import Client\n",
+    "from nemo_curator import DomainClassifier, QualityClassifier\n",
+    "from nemo_curator.datasets import DocumentDataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cluster = LocalCUDACluster(rmm_async=True, rmm_pool_size=\"1GB\")\n",
+    "client = Client(cluster)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Define the data file paths "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "input_file_path=\"/input_data_dir/\"\n",
+    "output_file_path = \"output_data_dir/\"\n",
+    "domain_model_path = \"domain_model.pth\"\n",
+    "quality_model_path = \"quality_model.pth\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Create a Classifier"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "classifier_type=\"DomainClassifier\" # or \"QualityClassifier\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Reading 16 files\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 10.5 s, sys: 5.33 s, total: 15.8 s\n",
+      "Wall time: 11.4 s\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "\n",
+    "input_dataset = DocumentDataset.read_json(\n",
+    "    input_file_path, backend=\"cudf\", add_filename=True\n",
+    ")\n",
+    "\n",
+    "if classifier_type == \"DomainClassifier\":\n",
+    "    domain_labels = [\n",
+    "    \"Adult\",\n",
+    "    \"Arts_and_Entertainment\",\n",
+    "    \"Autos_and_Vehicles\",\n",
+    "    \"Beauty_and_Fitness\",\n",
+    "    \"Books_and_Literature\",\n",
+    "    \"Business_and_Industrial\",\n",
+    "    \"Computers_and_Electronics\",\n",
+    "    \"Finance\",\n",
+    "    \"Food_and_Drink\",\n",
+    "    \"Games\",\n",
+    "    \"Health\",\n",
+    "    \"Hobbies_and_Leisure\",\n",
+    "    \"Home_and_Garden\",\n",
+    "    \"Internet_and_Telecom\",\n",
+    "    \"Jobs_and_Education\",\n",
+    "    \"Law_and_Government\",\n",
+    "    \"News\",\n",
+    "    \"Online_Communities\",\n",
+    "    \"People_and_Society\",\n",
+    "    \"Pets_and_Animals\",\n",
+    "    \"Real_Estate\",\n",
+    "    \"Science\",\n",
+    "    \"Sensitive_Subjects\",\n",
+    "    \"Shopping\",\n",
+    "    \"Sports\",\n",
+    "    \"Travel_and_Transportation\",\n",
+    "    ]\n",
+    "    classifier = DomainClassifier(\n",
+    "        model_path=domain_model_path,\n",
+    "        labels=domain_labels,\n",
+    "        batch_size=1024,\n",
+    "    )\n",
+    "elif classifier_type == \"QualityClassifier\":\n",
+    "    quality_labels = [\"High\", \"Medium\", \"Low\"]\n",
+    "    model_file_name = \"quality_classifier.pth\"\n",
+    "    classifier = QualityClassifier(\n",
+    "        model_path=quality_model_path,\n",
+    "        labels=quality_labels,\n",
+    "        batch_size=1024,\n",
+    "    )\n",
+    "else:\n",
+    "    raise ValueError(\"Invalid classifier type\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Run the  Classifier\n",
+    "\n",
+    "Dask operations are lazy, so the the classifier will not run until we call a eager operation like `to_json`, `compute` or `persist`. "
+   ]
+  },
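+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional illustration of the lazy behavior described above (safe to skip):\n",
+    "# calling the classifier only builds the Dask task graph; no GPU work runs\n",
+    "# until an eager operation such as to_json, compute, or persist is invoked.\n",
+    "lazy_result = classifier(dataset=input_dataset)\n",
+    "lazy_result.df  # repr of the lazy dask_cudf DataFrame, nothing computed yet"
+   ]
+  },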
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Starting domain classifier inference\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "GPU: 0, Part: 1: 100%|██████████| 938/938 [00:09<00:00, 101.99it/s] \n",
+      "GPU: 0, Part: 3: 100%|██████████| 938/938 [00:10<00:00, 92.36it/s] ]\n",
+      "GPU: 0, Part: 0: 100%|██████████| 938/938 [00:10<00:00, 91.25it/s] ]\n",
+      "GPU: 0, Part: 5: 100%|██████████| 938/938 [00:10<00:00, 88.82it/s] \n",
+      "GPU: 0, Part: 14: 100%|██████████| 937/937 [00:10<00:00, 88.11it/s] \n",
+      "GPU: 0, Part: 8: 100%|██████████| 937/937 [00:10<00:00, 85.46it/s] ]\n",
+      "GPU: 0, Part: 9: 100%|██████████| 937/937 [00:10<00:00, 86.16it/s] \n",
+      "GPU: 0, Part: 4: 100%|██████████| 938/938 [00:10<00:00, 85.65it/s]]\n",
+      "GPU: 0, Part: 11: 100%|██████████| 937/937 [00:11<00:00, 83.73it/s] \n",
+      "GPU: 0, Part: 6: 100%|██████████| 938/938 [00:11<00:00, 83.62it/s]\n",
+      "GPU: 0, Part: 10: 100%|██████████| 937/937 [00:11<00:00, 81.27it/s] \n",
+      "GPU: 0, Part: 2: 100%|██████████| 938/938 [00:12<00:00, 72.59it/s]]\n",
+      "GPU: 0, Part: 7: 100%|██████████| 937/937 [00:13<00:00, 71.75it/s]\n",
+      "GPU: 0, Part: 12: 100%|██████████| 937/937 [00:13<00:00, 69.12it/s]\n",
+      "GPU: 0, Part: 15: 100%|██████████| 937/937 [00:13<00:00, 68.47it/s]\n",
+      "GPU: 0, Part: 13: 100%|██████████| 937/937 [00:14<00:00, 66.29it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Writing to disk complete for 16 partitions\n",
+      "CPU times: user 2.34 s, sys: 2.24 s, total: 4.58 s\n",
+      "Wall time: 17.2 s\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "\n",
+    "result_dataset = classifier(dataset=input_dataset)\n",
+    "result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Inspect the Output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Reading 16 files\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>adlr_id</th>\n",
+       "      <th>domain_pred</th>\n",
+       "      <th>filename</th>\n",
+       "      <th>id</th>\n",
+       "      <th>pred</th>\n",
+       "      <th>source_id</th>\n",
+       "      <th>split_id</th>\n",
+       "      <th>text</th>\n",
+       "      <th>url</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>cc-2022-40-0431053204</td>\n",
+       "      <td>Online_Communities</td>\n",
+       "      <td>00.jsonl</td>\n",
+       "      <td>a8083fe4-525d-4888-8513-b91f43bd8ee1</td>\n",
+       "      <td>Online_Communities</td>\n",
+       "      <td>crawl-data-CC-MAIN-2022-40-segments-1664030336...</td>\n",
+       "      <td>lambada-0003225258-0000</td>\n",
+       "      <td>Having been a community leader—and member—for ...</td>\n",
+       "      <td>https://lisalarter.com/7-tips-for-building-ste...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>cc-2022-40-0510168267</td>\n",
+       "      <td>Finance</td>\n",
+       "      <td>00.jsonl</td>\n",
+       "      <td>559febdc-cb7f-4217-897a-c8dac325123b</td>\n",
+       "      <td>Finance</td>\n",
+       "      <td>crawl-data-CC-MAIN-2022-40-segments-1664030337...</td>\n",
+       "      <td>lambada-0003918122-0000</td>\n",
+       "      <td>Zelle is a way of sending money to almost anyo...</td>\n",
+       "      <td>https://oregonmassageandwellnessclinic.com/app...</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "                 adlr_id         domain_pred  filename  \\\n",
+       "0  cc-2022-40-0431053204  Online_Communities  00.jsonl   \n",
+       "1  cc-2022-40-0510168267             Finance  00.jsonl   \n",
+       "\n",
+       "                                     id                pred  \\\n",
+       "0  a8083fe4-525d-4888-8513-b91f43bd8ee1  Online_Communities   \n",
+       "1  559febdc-cb7f-4217-897a-c8dac325123b             Finance   \n",
+       "\n",
+       "                                           source_id                 split_id  \\\n",
+       "0  crawl-data-CC-MAIN-2022-40-segments-1664030336...  lambada-0003225258-0000   \n",
+       "1  crawl-data-CC-MAIN-2022-40-segments-1664030337...  lambada-0003918122-0000   \n",
+       "\n",
+       "                                                text  \\\n",
+       "0  Having been a community leader—and member—for ...   \n",
+       "1  Zelle is a way of sending money to almost anyo...   \n",
+       "\n",
+       "                                                 url  \n",
+       "0  https://lisalarter.com/7-tips-for-building-ste...  \n",
+       "1  https://oregonmassageandwellnessclinic.com/app...  "
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "output_dataset = DocumentDataset.read_json(output_file_path, backend=\"cudf\", add_filename=True)\n",
+    "output_dataset.df.head(2)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "##### Cleanup the output file"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!rm -rf $output_file_path"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "NeMo-Curator-env-2",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}