Multilingual Domain Classifier (#363)
* initial commit

Signed-off-by: Sarah Yurick <[email protected]>

* run black

Signed-off-by: Sarah Yurick <[email protected]>

* combine with DomainClassifier

Signed-off-by: Sarah Yurick <[email protected]>

* isort

Signed-off-by: Sarah Yurick <[email protected]>

* add links

Signed-off-by: Sarah Yurick <[email protected]>

* add praateek's suggestion

Signed-off-by: Sarah Yurick <[email protected]>

* add ryan's suggestion

Signed-off-by: Sarah Yurick <[email protected]>

* update readmes

Signed-off-by: Sarah Yurick <[email protected]>

* create MultilingualDomainClassifier

Signed-off-by: Sarah Yurick <[email protected]>

* add api

Signed-off-by: Sarah Yurick <[email protected]>

---------

Signed-off-by: Sarah Yurick <[email protected]>
sarahyurick authored Dec 3, 2024
1 parent edd6262 commit 7272ca0
Showing 11 changed files with 370 additions and 34 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -28,7 +28,7 @@ All of our text pipelines have great multilingual support.
 - [Heuristic Filtering](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html)
 - Classifier Filtering
   - [fastText](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html)
-  - GPU-Accelerated models: [Domain, Quality, and Safety Classification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html)
+  - GPU-Accelerated models: [Domain (English and multilingual), Quality, and Safety Classification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html)
 - **GPU-Accelerated Deduplication**
   - [Exact Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html)
   - [Fuzzy Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) via MinHash Locality Sensitive Hashing
5 changes: 4 additions & 1 deletion docs/user-guide/api/classifiers.rst
@@ -5,11 +5,14 @@ Classifiers
 .. autoclass:: nemo_curator.classifiers.DomainClassifier
    :members:
 
+.. autoclass:: nemo_curator.classifiers.MultilingualDomainClassifier
+   :members:
+
 .. autoclass:: nemo_curator.classifiers.QualityClassifier
    :members:
 
 .. autoclass:: nemo_curator.classifiers.FineWebEduClassifier
    :members:
 
 .. autoclass:: nemo_curator.classifiers.AegisClassifier
-   :members:
+   :members:
2 changes: 1 addition & 1 deletion docs/user-guide/cpuvsgpu.rst
@@ -67,7 +67,7 @@ The following NeMo Curator modules are GPU based.
 * Semantic Deduplication
 * Distributed Data Classification
 
-  * Domain Classification
+  * Domain Classification (English and multilingual)
   * Quality Classification
 
 GPU modules store the ``DocumentDataset`` using a ``cudf`` backend instead of a ``pandas`` one.
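For illustration, a minimal sketch of what the ``cudf`` backend means in practice, based on the ``read_json`` call used in the classifier docs and example script below (the input path is hypothetical):

.. code-block:: python

    from nemo_curator.datasets import DocumentDataset

    # backend="cudf" loads the dataset into GPU memory for the GPU-based modules;
    # backend="pandas" would keep it on the CPU instead
    gpu_dataset = DocumentDataset.read_json("books_dataset/", backend="cudf")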
29 changes: 27 additions & 2 deletions docs/user-guide/distributeddataclassification.rst
@@ -15,12 +15,14 @@ NeMo Curator provides a module to help users run inference with pre-trained models
 This is achieved by chunking the datasets across multiple computing nodes, each equipped with multiple GPUs, to accelerate the classification task in a distributed manner.
 Since the classification of a single text document is independent of other documents within the dataset, we can distribute the workload across multiple nodes and GPUs to perform parallel processing.
 
-Domain, quality, content safety, and educational content models are tasks we include as examples within our module.
+Domain (English and multilingual), quality, content safety, and educational content models are tasks we include as examples within our module.
 
 Here, we summarize why each is useful for training an LLM:
 
 - The **Domain Classifier** is useful because it helps the LLM understand the context and specific domain of the input text. Because different domains have different linguistic characteristics and terminologies, an LLM's ability to generate contextually relevant responses can be improved by tailoring training data to a specific domain. Overall, this helps provide more accurate and specialized information.
 
+- The **Multilingual Domain Classifier** is the same as the domain classifier, but has been trained to classify text in 52 languages, including English.
+
 - The **Quality Classifier** is useful for filtering out noisy or low quality data. This allows the model to focus on learning from high quality and informative examples, which contributes to the LLM's robustness and enhances its ability to generate reliable and meaningful outputs. Additionally, quality classification helps mitigate biases and inaccuracies that may arise from poorly curated training data.
 
 - The **AEGIS Safety Models** are essential for filtering harmful or risky content, which is critical for training models that should avoid learning from unsafe data. By classifying content into 13 critical risk categories, AEGIS helps remove harmful or inappropriate data from the training sets, improving the overall ethical and safety standards of the LLM.
@@ -45,7 +47,7 @@ Check out ``nemo_curator.classifiers.base.py`` for reference.
 Domain Classifier
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-The Domain Classifier is used to categorize text documents into specific domains or subject areas. This is particularly useful for organizing large datasets and tailoring the training data for domain-specific LLMs.
+The Domain Classifier is used to categorize English text documents into specific domains or subject areas. This is particularly useful for organizing large datasets and tailoring the training data for domain-specific LLMs.
 
 Let's see how ``DomainClassifier`` works in a small excerpt taken from ``examples/classifiers/domain_example.py``:
 
@@ -64,6 +66,29 @@
 In this example, the domain classifier is obtained directly from `Hugging Face <https://huggingface.co/nvidia/domain-classifier>`_.
 It filters the input dataset to include only documents classified as "Games" or "Sports".
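
The excerpt itself is collapsed in this view; a minimal sketch of the same call pattern, mirroring the multilingual example added below (the input path is hypothetical, and the full excerpt lives in ``examples/classifiers/domain_example.py``):

.. code-block:: python

    from nemo_curator.classifiers import DomainClassifier
    from nemo_curator.datasets import DocumentDataset

    # Hypothetical input path; read with the cudf backend for GPU inference
    input_dataset = DocumentDataset.read_json("books_dataset/", backend="cudf")

    domain_classifier = DomainClassifier(filter_by=["Games", "Sports"])
    result_dataset = domain_classifier(dataset=input_dataset)
    result_dataset.to_json("games_and_sports/")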

+Multilingual Domain Classifier
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The Multilingual Domain Classifier is used to categorize text documents across 52 languages into specific domains or subject areas.
+
+Using the ``MultilingualDomainClassifier`` is very similar to using the ``DomainClassifier`` as described above. Here is an example:
+
+.. code-block:: python
+
+    from nemo_curator.classifiers import MultilingualDomainClassifier
+
+    files = get_all_files_paths_under("japanese_books_dataset/")
+    input_dataset = DocumentDataset.read_json(files, backend="cudf")
+
+    multilingual_domain_classifier = MultilingualDomainClassifier(
+        filter_by=["Games", "Sports"],
+    )
+
+    result_dataset = multilingual_domain_classifier(dataset=input_dataset)
+    result_dataset.to_json("games_and_sports/")
+
+For more information about the multilingual domain classifier, including its supported languages, please see the `nvidia/multilingual-domain-classifier <https://huggingface.co/nvidia/multilingual-domain-classifier>`_ on Hugging Face.
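
As a usage note: judging by the attributes documented in ``domain.py`` below, passing ``pred_column``/``prob_column`` instead of ``filter_by`` should keep every document and record the model's outputs rather than filtering. A sketch under that assumption (the ``domain_prob`` column name is hypothetical):

.. code-block:: python

    # Keep all documents and store predictions rather than filtering
    classifier = MultilingualDomainClassifier(
        pred_column="domain_pred",   # documented default
        prob_column="domain_prob",   # hypothetical name; defaults to None (disabled)
    )
    scored_dataset = classifier(dataset=input_dataset)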

Quality Classifier
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

3 changes: 2 additions & 1 deletion examples/classifiers/README.md
@@ -1,8 +1,9 @@
 ## Text Classification
 
-The Python scripts in this directory demonstrate how to run classification on your text data with each of these 4 classifiers:
+The Python scripts in this directory demonstrate how to run classification on your text data with each of these 5 classifiers:
 
 - Domain Classifier
+- Multilingual Domain Classifier
 - Quality Classifier
 - AEGIS Safety Models
 - FineWeb Educational Content Classifier
67 changes: 67 additions & 0 deletions examples/classifiers/multilingual_domain_example.py
@@ -0,0 +1,67 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time

from nemo_curator.classifiers import MultilingualDomainClassifier
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import ArgumentHelper


def main(args):
    global_st = time.time()

    # Input can be a string or list
    input_file_path = "/path/to/data"
    output_file_path = "./"

    client_args = ArgumentHelper.parse_client_args(args)
    client_args["cluster_type"] = "gpu"
    client = get_client(**client_args)

    input_dataset = DocumentDataset.read_json(
        input_file_path, backend="cudf", add_filename=True
    )

    multilingual_domain_classifier = MultilingualDomainClassifier(
        filter_by=["Games", "Sports"]
    )
    result_dataset = multilingual_domain_classifier(dataset=input_dataset)

    result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)

    global_et = time.time()
    print(
        f"Total time taken for multilingual domain classifier inference: {global_et-global_st} s",
        flush=True,
    )

    client.close()


def attach_args(
    parser=argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    ),
):
    argumentHelper = ArgumentHelper(parser)
    argumentHelper.add_distributed_classifier_cluster_args()

    return argumentHelper.parser


if __name__ == "__main__":
    main(attach_args().parse_args())
3 changes: 2 additions & 1 deletion nemo_curator/classifiers/__init__.py
@@ -16,12 +16,13 @@
 
 os.environ["RAPIDS_NO_INITIALIZE"] = "1"
 from .aegis import AegisClassifier, InstructionDataGuardClassifier
-from .domain import DomainClassifier
+from .domain import DomainClassifier, MultilingualDomainClassifier
 from .fineweb_edu import FineWebEduClassifier
 from .quality import QualityClassifier
 
 __all__ = [
     "DomainClassifier",
+    "MultilingualDomainClassifier",
     "QualityClassifier",
     "AegisClassifier",
     "InstructionDataGuardClassifier",
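With the updated exports, both classifiers can be imported from the package root, e.g.:

```python
# Import both the English and multilingual domain classifiers
from nemo_curator.classifiers import DomainClassifier, MultilingualDomainClassifier
```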
154 changes: 128 additions & 26 deletions nemo_curator/classifiers/domain.py
@@ -28,11 +28,15 @@
 from nemo_curator.datasets import DocumentDataset
 
 DOMAIN_IDENTIFIER = "nvidia/domain-classifier"
+DOMAIN_BASE_MODEL = "microsoft/deberta-v3-base"
+MULTILINGUAL_DOMAIN_IDENTIFIER = "nvidia/multilingual-domain-classifier"
+MULTILINGUAL_DOMAIN_BASE_MODEL = "microsoft/mdeberta-v3-base"
 
 
 @dataclass
 class DomainModelConfig:
-    model: str = "microsoft/deberta-v3-base"
+    identifier: str = DOMAIN_IDENTIFIER
+    base_model: str = DOMAIN_BASE_MODEL
     fc_dropout: float = 0.2
     max_len: int = 512
 
@@ -49,44 +53,30 @@ def __init__(
         if max_mem_gb is None:
             max_mem_gb = _get_suggest_memory_for_classifier()
 
-        super().__init__(self.config.model, max_mem_gb=max_mem_gb)
+        super().__init__(self.config.base_model, max_mem_gb=max_mem_gb)
 
     def load_model(self, device: str = "cuda"):
-        model = HFDeberta.from_pretrained(DOMAIN_IDENTIFIER)
+        model = HFDeberta.from_pretrained(self.config.identifier)
         model.set_autocast(self.autocast)
         model = model.to(device)
        return model.eval()
 
     def load_tokenizer(self):
-        return AutoTokenizer.from_pretrained(DOMAIN_IDENTIFIER)
+        return AutoTokenizer.from_pretrained(self.config.identifier)
 
     def load_config(self):
-        return AutoConfig.from_pretrained(DOMAIN_IDENTIFIER)
+        return AutoConfig.from_pretrained(self.config.identifier)
 
 
-class DomainClassifier(DistributedDataClassifier):
+class _DomainClassifier(DistributedDataClassifier):
     """
-    DomainClassifier is a specialized classifier designed for domain classification tasks, utilizing the
-    NVIDIA Domain Classifier model (https://huggingface.co/nvidia/domain-classifier). This class is optimized
-    for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets.
-    Attributes:
-        filter_by (list[str], optional): The classes to filter the dataset by.
-            If None, all classes will be included. Defaults to None.
-        batch_size (int): The number of samples per batch for inference. Defaults to 256.
-        text_field (str): The field in the dataset that should be classified.
-        pred_column (str): The column name where predictions will be stored. Defaults to "domain_pred".
-        prob_column (str, optional): The column name where prediction probabilities will be stored. Defaults to None.
-        max_chars (int): The maximum number of characters in each document to consider for classification. Defaults to 2000.
-        device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda".
-        autocast (bool): Whether to use mixed precision for faster inference. Defaults to True.
-        max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
-            it defaults to the available GPU memory minus 4 GB.
+    Parent class for DomainClassifier and MultilingualDomainClassifier,
+    since their implementations are almost identical.
     """
 
     def __init__(
         self,
+        multilingual: bool = False,
         filter_by: Optional[List[str]] = None,
         batch_size: int = 256,
         text_field: str = "text",
@@ -97,7 +87,20 @@ def __init__(
         autocast: bool = True,
         max_mem_gb: Optional[int] = None,
     ):
-        config = AutoConfig.from_pretrained(DOMAIN_IDENTIFIER)
+        self.multilingual = multilingual
+
+        if multilingual:
+            config = AutoConfig.from_pretrained(MULTILINGUAL_DOMAIN_IDENTIFIER)
+            model_config = DomainModelConfig(
+                identifier=MULTILINGUAL_DOMAIN_IDENTIFIER,
+                base_model=MULTILINGUAL_DOMAIN_BASE_MODEL,
+            )
+        else:
+            config = AutoConfig.from_pretrained(DOMAIN_IDENTIFIER)
+            model_config = DomainModelConfig(
+                identifier=DOMAIN_IDENTIFIER,
+                base_model=DOMAIN_BASE_MODEL,
+            )
 
         self.text_field = text_field
         self.prob_column = prob_column
@@ -106,7 +109,7 @@ def __init__(
         self.out_dim = len(self.labels)
 
         model = DomainModel(
-            config=DomainModelConfig, autocast=autocast, max_mem_gb=max_mem_gb
+            config=model_config, autocast=autocast, max_mem_gb=max_mem_gb
         )
 
         super().__init__(
@@ -122,7 +125,11 @@ def __init__(
         )
 
     def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
-        print("Starting domain classifier inference", flush=True)
+        if self.multilingual:
+            print("Starting multilingual domain classifier inference", flush=True)
+        else:
+            print("Starting domain classifier inference", flush=True)
 
         df = dataset.df
         df = _run_classifier_helper(
             df=df,
@@ -135,3 +142,98 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
             prob_col=self.prob_column,
         )
         return DocumentDataset(df)


+class DomainClassifier(_DomainClassifier):
+    """
+    DomainClassifier is a specialized classifier designed for English text domain classification tasks,
+    utilizing the NVIDIA Domain Classifier (https://huggingface.co/nvidia/domain-classifier) model.
+    This class is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets.
+    Attributes:
+        filter_by (list[str], optional): The classes to filter the dataset by.
+            If None, all classes will be included. Defaults to None.
+        batch_size (int): The number of samples per batch for inference. Defaults to 256.
+        text_field (str): The field in the dataset that should be classified.
+        pred_column (str): The column name where predictions will be stored. Defaults to "domain_pred".
+        prob_column (str, optional): The column name where prediction probabilities will be stored. Defaults to None.
+        max_chars (int): The maximum number of characters in each document to consider for classification. Defaults to 2000.
+        device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda".
+        autocast (bool): Whether to use mixed precision for faster inference. Defaults to True.
+        max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
+            it defaults to the available GPU memory minus 4 GB.
+    """
+
+    def __init__(
+        self,
+        filter_by: Optional[List[str]] = None,
+        batch_size: int = 256,
+        text_field: str = "text",
+        pred_column: str = "domain_pred",
+        prob_column: Optional[str] = None,
+        max_chars: int = 2000,
+        device_type: str = "cuda",
+        autocast: bool = True,
+        max_mem_gb: Optional[int] = None,
+    ):
+        super().__init__(
+            multilingual=False,
+            filter_by=filter_by,
+            batch_size=batch_size,
+            text_field=text_field,
+            pred_column=pred_column,
+            prob_column=prob_column,
+            max_chars=max_chars,
+            device_type=device_type,
+            autocast=autocast,
+            max_mem_gb=max_mem_gb,
+        )
+
+
+class MultilingualDomainClassifier(_DomainClassifier):
+    """
+    MultilingualDomainClassifier is a specialized classifier designed for domain classification tasks,
+    utilizing the NVIDIA Multilingual Domain Classifier (https://huggingface.co/nvidia/multilingual-domain-classifier) model.
+    It supports domain classification across 52 languages.
+    This class is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets.
+    Attributes:
+        filter_by (list[str], optional): The classes to filter the dataset by.
+            If None, all classes will be included. Defaults to None.
+        batch_size (int): The number of samples per batch for inference. Defaults to 256.
+        text_field (str): The field in the dataset that should be classified.
+        pred_column (str): The column name where predictions will be stored. Defaults to "domain_pred".
+        prob_column (str, optional): The column name where prediction probabilities will be stored. Defaults to None.
+        max_chars (int): The maximum number of characters in each document to consider for classification. Defaults to 2000.
+        device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda".
+        autocast (bool): Whether to use mixed precision for faster inference. Defaults to True.
+        max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
+            it defaults to the available GPU memory minus 4 GB.
+    """
+
+    def __init__(
+        self,
+        filter_by: Optional[List[str]] = None,
+        batch_size: int = 256,
+        text_field: str = "text",
+        pred_column: str = "domain_pred",
+        prob_column: Optional[str] = None,
+        max_chars: int = 2000,
+        device_type: str = "cuda",
+        autocast: bool = True,
+        max_mem_gb: Optional[int] = None,
+    ):
+        super().__init__(
+            multilingual=True,
+            filter_by=filter_by,
+            batch_size=batch_size,
+            text_field=text_field,
+            pred_column=pred_column,
+            prob_column=prob_column,
+            max_chars=max_chars,
+            device_type=device_type,
+            autocast=autocast,
+            max_mem_gb=max_mem_gb,
+        )
