
Commit

Address Reviews
VibhuJawa committed May 20, 2024
1 parent 6bff4fa commit e2811e7
Showing 7 changed files with 74 additions and 62 deletions.
File renamed without changes.
File renamed without changes.
40 changes: 23 additions & 17 deletions nemo_curator/modules/distributed_data_classifier.py
@@ -31,14 +31,14 @@


@dataclass
class domain_Config:
class DomainModelConfig:
    model = "microsoft/deberta-v3-base"
    fc_dropout = 0.2
    max_len = 512


@dataclass
class quality_Config:
class QualityModelConfig:
    model = "microsoft/deberta-v3-base"
    fc_dropout = 0.2
    max_len = 512
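
For context, a minimal sketch (not part of this commit) of how these config classes are consumed; the fields are class-level defaults, so the class itself is passed as the `config` argument, and the tokenizer call is the standard Hugging Face API:

from dataclasses import dataclass

from transformers import AutoTokenizer


@dataclass
class DomainModelConfig:
    model = "microsoft/deberta-v3-base"
    fc_dropout = 0.2
    max_len = 512


# The class itself acts as the config object, as in DomainModel(config=DomainModelConfig, ...).
tokenizer = AutoTokenizer.from_pretrained(DomainModelConfig.model)
print(DomainModelConfig.fc_dropout, DomainModelConfig.max_len)  # 0.2 512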
@@ -170,7 +170,6 @@ def _run_classifier_helper(
    prob_internal_col = "_prob"
    # TODO: Make crossfit handle this cleanly
    pred_internal_col = "labels"

    df["sliced_text"] = df["text"].str.slice(0, max_chars)
    columns_to_keep_list = df.columns.to_list()
    columns_to_keep_list.remove("sliced_text")
@@ -197,12 +196,11 @@ def _run_classifier_helper(
    df = labeling_pipe(df)
    if keep_prob:
        df = df.rename(
            columns={pred_internal_col: label_col, prob_internal_col: prob_col}
            columns={prob_internal_col: prob_col, pred_internal_col: label_col},
        )
    else:
        df = df.rename(columns={pred_internal_col: label_col})
        df = df.drop(columns=[prob_internal_col])

    return df


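As an aside, a toy sketch of the column handling in `_run_classifier_helper` above; it uses pandas for illustration (the real code runs on cuDF via CrossFit), and the user-facing column names are placeholders:

import pandas as pd

df = pd.DataFrame({"_prob": [0.92], "labels": ["Sports"]})
keep_prob = True

if keep_prob:
    # Keep both columns, renamed to their user-facing names.
    df = df.rename(columns={"_prob": "domain_prob", "labels": "domain_pred"})
else:
    # Keep only the label and drop the internal probability column.
    df = df.rename(columns={"labels": "domain_pred"}).drop(columns=["_prob"])
print(df)
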
@@ -229,6 +227,8 @@ def load_model(self, device="cuda"):
            if version.parse(TRANSFORMERS_VERSION) >= version.parse("4.31.0"):
                sd.pop("model.embeddings.position_ids", None)
            model.load_state_dict(sd, strict=True)
        else:
            raise ValueError(f"Model path {self.model_path} does not exist")
        return model.eval()

    def load_tokenizer(self):
@@ -255,11 +255,14 @@ def load_model(self, device="cuda"):
            autocast=self.autocast,
        )
        model = model.to(device)
        sd = torch.load(self.model_path, map_location="cpu")
        if "model_state_dict" in sd:
            sd = sd["model_state_dict"]
        sd = {k[7:] if k.startswith("module.") else k: sd[k] for k in sd.keys()}
        model.load_state_dict(sd, strict=True)
        if os.path.exists(self.model_path):
            sd = torch.load(self.model_path, map_location="cpu")
            if "model_state_dict" in sd:
                sd = sd["model_state_dict"]
            sd = {k[7:] if k.startswith("module.") else k: sd[k] for k in sd.keys()}
            model.load_state_dict(sd, strict=True)
        else:
            raise ValueError(f"Model path {self.model_path} does not exist")
        model.eval()
        return model

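The checkpoint-loading pattern above, restated as a standalone sketch; `load_checkpoint_state` is an illustrative name, not a NeMo Curator function:

import os

import torch


def load_checkpoint_state(model_path: str) -> dict:
    # Fail fast with a clear error rather than letting torch.load raise.
    if not os.path.exists(model_path):
        raise ValueError(f"Model path {model_path} does not exist")
    sd = torch.load(model_path, map_location="cpu")
    # Checkpoints saved from training loops often nest the weights.
    if "model_state_dict" in sd:
        sd = sd["model_state_dict"]
    # Strip the "module." prefix that DistributedDataParallel adds to keys.
    return {k[len("module."):] if k.startswith("module.") else k: v for k, v in sd.items()}
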
@@ -273,23 +276,26 @@ def load_config(self):
class DomainClassifier(DistributedDataClassifier):
    def __init__(
        self,
        model_file_name,
        model_path,
        labels,
        filter_by=None,
        batch_size=256,
        out_dim=None,
        pred_column="domain_pred",
        prob_column=None,
        max_chars=2000,
        device_type="cuda",
        autocast=True,
    ):
        if out_dim is None:
            out_dim = len(labels)

        self.prob_column = prob_column

        model = DomainModel(
            config=domain_Config,
            config=DomainModelConfig,
            out_dim=out_dim,
            model_path=model_file_name,
            model_path=model_path,
            autocast=autocast,
        )

@@ -315,15 +321,15 @@ def _run_classifier(self, dataset: DocumentDataset):
            max_chars=self.max_chars,
            batch_size=self.batch_size,
            label_col=self.pred_column,
            keep_prob=False,
            prob_col=self.prob_column,
        )
        return DocumentDataset(df)


class QualityClassifier(DistributedDataClassifier):
    def __init__(
        self,
        model_file_name,
        model_path,
        labels,
        filter_by=None,
        batch_size=256,
@@ -345,9 +351,9 @@ def __init__(
        self.max_len = max_len

        model = QualityModel(
            config=quality_Config,
            config=QualityModelConfig,
            out_dim=out_dim,
            model_path=model_file_name,
            model_path=model_path,
            autocast=autocast,
        )

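Taken together, a minimal usage sketch of the renamed keyword (`model_file_name` → `model_path`); the paths and the shortened domain label list are placeholders, not values shipped with NeMo Curator:

from nemo_curator import DomainClassifier, QualityClassifier

domain_classifier = DomainClassifier(
    model_path="/models/domain_model.pth",  # placeholder path
    labels=["Arts_and_Entertainment", "Science", "Sports"],  # illustrative subset
    batch_size=1024,
)
quality_classifier = QualityClassifier(
    model_path="/models/quality_model.pth",  # placeholder path
    labels=["High", "Medium", "Low"],
    batch_size=1024,
)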
@@ -19,7 +19,9 @@
os.environ["RAPIDS_NO_INITIALIZE"] = "1"
from nemo_curator import DomainClassifier
from nemo_curator.datasets import DocumentDataset
from nemo_curator.distributed_data_classification.arg_utils import create_arg_parser

# Get relevant args
from nemo_curator.scripts.classifier_arg_utils import create_arg_parser
from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk
from nemo_curator.utils.file_utils import get_remaining_files

@@ -19,10 +19,11 @@
os.environ["RAPIDS_NO_INITIALIZE"] = "1"
from nemo_curator import QualityClassifier
from nemo_curator.datasets import DocumentDataset
from nemo_curator.distributed_data_classification.arg_utils import create_arg_parser
from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk
from nemo_curator.utils.file_utils import get_remaining_files

from .classifier_arg_utils import create_arg_parser

warnings.filterwarnings("ignore")


@@ -4,7 +4,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Silence Warnings (Majorly HF internal warnings)"
"## Distributed Data Classification with Quality and Domain Classifiers\n",
"\n",
"The notebook demonstrates the use of two classifiers for distributed data classification, including quality and domain classifiers. The quality classifier is used to classify the quality of the data, while the domain classifier is used to classify the domain of the data.These classifers help with annotation which helps data blending for foundation model training. \n",
"\n",
"The classifiers are accelerated using CrossFit,(https://github.com/rapidsai/crossfit), a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets."
]
},
{
@@ -21,8 +25,9 @@
}
],
"source": [
"%env PYTHONWARNINGS=ignore\n",
"#### Silence Warnings (HuggingFace internal warnings)\n",
"\n",
"%env PYTHONWARNINGS=ignore\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
@@ -53,17 +58,19 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data File Paths "
"# Define the data file paths "
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"input_file_path=\"/home/nfs/syurick/LLM_domain_classifier_inference/4360_results_jsonl_dir/\"\n",
"output_file_path = \"/raid/vjawa/output_file.parquet\""
"input_file_path=\"/input_data_dir/\"\n",
"output_file_path = \"output_data_dir/\"\n",
"domain_model_path = \"domain_model.pth\"\n",
"quality_model_path = \"quality_model.pth\""
]
},
{
@@ -79,7 +86,7 @@
"metadata": {},
"outputs": [],
"source": [
"classifier_type=\"DomainClassifier\""
"classifier_type=\"DomainClassifier\" # or \"QualityClassifier\""
]
},
{
@@ -98,8 +105,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 7.14 s, sys: 4.91 s, total: 12 s\n",
"Wall time: 7.95 s\n"
"CPU times: user 10.5 s, sys: 5.33 s, total: 15.8 s\n",
"Wall time: 11.4 s\n"
]
}
],
@@ -139,21 +146,16 @@
" \"Sports\",\n",
" \"Travel_and_Transportation\",\n",
" ]\n",
" model_file_name = \"/home/nfs/syurick/LLM_domain_classifier_inference/\" + \\\n",
" \"GoogleDebertaAgree_v3b_bce_maxlen512_bs64_noRef_best.pth\"\n",
" classifier = DomainClassifier(\n",
" model_file_name=model_file_name,\n",
" model_path=domain_model_path,\n",
" labels=domain_labels,\n",
" batch_size=1024,\n",
" )\n",
"elif classifier_type == \"QualityClassifier\":\n",
" quality_labels = [\"High\", \"Medium\", \"Low\"]\n",
" model_file_name = \"/home/nfs/syurick/LLM_quality_classifier_inference/\" + \\\n",
" \"quality_rnd3_2014val1070_10ep_2xhigh_1024_fold4_last-001.pth\"\n",
"\n",
"\n",
" model_file_name = \"quality_classifier.pth\"\n",
" classifier = QualityClassifier(\n",
" model_file_name=model_file_name,\n",
" model_path=quality_model_path,\n",
" labels=quality_labels,\n",
" batch_size=1024,\n",
" )\n",
@@ -165,12 +167,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Run the actuall Classifier"
"# Run the Classifier\n",
"\n",
"Dask operations are lazy, so the the classifier will not run until we call a eager operation like `to_json`, `compute` or `persist`. "
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -184,52 +188,51 @@
"name": "stderr",
"output_type": "stream",
"text": [
"GPU: 0, Part: 8: 100%|██████████| 937/937 [00:15<00:00, 62.45it/s]]]\n",
"GPU: 0, Part: 9: 100%|██████████| 937/937 [00:13<00:00, 68.83it/s] \n",
"GPU: 0, Part: 5: 100%|██████████| 938/938 [00:13<00:00, 67.67it/s] \n",
"GPU: 0, Part: 15: 100%|██████████| 937/937 [00:15<00:00, 62.42it/s] \n",
"GPU: 0, Part: 10: 100%|██████████| 937/937 [00:14<00:00, 64.39it/s] \n",
"GPU: 0, Part: 3: 100%|██████████| 938/938 [00:13<00:00, 67.37it/s] \n",
"GPU: 0, Part: 7: 100%|██████████| 937/937 [00:14<00:00, 66.83it/s] \n",
"GPU: 0, Part: 11: 100%|██████████| 937/937 [00:14<00:00, 65.50it/s]\n",
"GPU: 0, Part: 2: 100%|██████████| 938/938 [00:13<00:00, 68.84it/s]\n",
"GPU: 0, Part: 6: 100%|██████████| 938/938 [00:14<00:00, 64.48it/s]\n",
"GPU: 0, Part: 4: 100%|██████████| 938/938 [00:14<00:00, 63.31it/s]\n",
"GPU: 0, Part: 13: 100%|██████████| 937/937 [00:14<00:00, 63.18it/s]\n",
"GPU: 0, Part: 12: 100%|██████████| 937/937 [00:14<00:00, 62.92it/s]\n",
"GPU: 0, Part: 0: 100%|██████████| 938/938 [00:15<00:00, 62.14it/s]\n",
"GPU: 0, Part: 14: 100%|██████████| 937/937 [00:15<00:00, 62.04it/s]\n",
"GPU: 0, Part: 1: 100%|██████████| 938/938 [00:16<00:00, 56.99it/s]\n"
"GPU: 0, Part: 1: 100%|██████████| 938/938 [00:09<00:00, 101.99it/s] \n",
"GPU: 0, Part: 3: 100%|██████████| 938/938 [00:10<00:00, 92.36it/s] ]\n",
"GPU: 0, Part: 0: 100%|██████████| 938/938 [00:10<00:00, 91.25it/s] ]\n",
"GPU: 0, Part: 5: 100%|██████████| 938/938 [00:10<00:00, 88.82it/s] \n",
"GPU: 0, Part: 14: 100%|██████████| 937/937 [00:10<00:00, 88.11it/s] \n",
"GPU: 0, Part: 8: 100%|██████████| 937/937 [00:10<00:00, 85.46it/s] ]\n",
"GPU: 0, Part: 9: 100%|██████████| 937/937 [00:10<00:00, 86.16it/s] \n",
"GPU: 0, Part: 4: 100%|██████████| 938/938 [00:10<00:00, 85.65it/s]]\n",
"GPU: 0, Part: 11: 100%|██████████| 937/937 [00:11<00:00, 83.73it/s] \n",
"GPU: 0, Part: 6: 100%|██████████| 938/938 [00:11<00:00, 83.62it/s]\n",
"GPU: 0, Part: 10: 100%|██████████| 937/937 [00:11<00:00, 81.27it/s] \n",
"GPU: 0, Part: 2: 100%|██████████| 938/938 [00:12<00:00, 72.59it/s]]\n",
"GPU: 0, Part: 7: 100%|██████████| 937/937 [00:13<00:00, 71.75it/s]\n",
"GPU: 0, Part: 12: 100%|██████████| 937/937 [00:13<00:00, 69.12it/s]\n",
"GPU: 0, Part: 15: 100%|██████████| 937/937 [00:13<00:00, 68.47it/s]\n",
"GPU: 0, Part: 13: 100%|██████████| 937/937 [00:14<00:00, 66.29it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Writing to disk complete for 16 partitions\n",
"CPU times: user 4.61 s, sys: 6.09 s, total: 10.7 s\n",
"Wall time: 39.8 s\n"
"CPU times: user 2.34 s, sys: 2.24 s, total: 4.58 s\n",
"Wall time: 17.2 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"result_dataset = classifier(dataset=input_dataset)\n",
"result_dataset.df = result_dataset.df.rename(columns={\"labels\": f\"{classifier_type}_prediction\"})\n",
"result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Verify The file was written correctly"
"#### Inspect the Output"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -322,7 +325,7 @@
"1 https://oregonmassageandwellnessclinic.com/app... "
]
},
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -336,12 +339,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"##### cleanup the output file"
"##### Cleanup the output file"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
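The cleanup cell's source is truncated in this view; a plausible sketch (an assumption, not the notebook's actual code) that removes the output directory written above:

import shutil

# Remove the output directory created by to_json above; path from the notebook.
shutil.rmtree("output_data_dir/", ignore_errors=True)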
