Skip to content

Commit

Permalink
Update based on reviews
Browse files Browse the repository at this point in the history
Signed-off-by: Vibhu Jawa <[email protected]>
  • Loading branch information
VibhuJawa committed May 21, 2024
1 parent 02fe2ef commit ca713c8
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 261 deletions.
2 changes: 1 addition & 1 deletion examples/domain_classifier_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def main(args):
)

domain_classifier = DomainClassifier(
model_file_name=model_file_name,
model_path=model_file_name,
labels=labels,
filter_by=["Games", "Sports"],
)
Expand Down
2 changes: 1 addition & 1 deletion examples/quality_classifier_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def main(args):
)

quality_classifier = QualityClassifier(
model_file_name=model_file_name,
model_path=model_file_name,
labels=labels,
filter_by=["High", "Medium"],
)
Expand Down
81 changes: 0 additions & 81 deletions nemo_curator/modules/distributed_data_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,53 +97,6 @@ def forward(self, batch):
return self._forward(batch)


class CustomModel(nn.Module):
    """Transformer backbone followed by dropout and a linear classification head.

    Args:
        config: Object exposing ``model`` (name or path handed to the
            Hugging Face loaders) and ``fc_dropout`` (dropout probability
            applied before the linear head).
        out_dim: Number of output features of the linear head.
        config_path: Optional path to a serialized backbone config. When
            ``None`` the backbone config is fetched with
            ``AutoConfig.from_pretrained(config.model)``.
        pretrained: When true, load pretrained backbone weights; otherwise
            build the backbone from its config only.
        autocast: When true, ``forward`` runs under ``torch.autocast`` for CUDA.

    NOTE(review): ``forward`` delegates to ``self._forward``, which is not
    defined in this class as visible here -- confirm it is supplied elsewhere.
    """

    def __init__(
        self, config, out_dim, config_path=None, pretrained=False, autocast=False
    ):
        super().__init__()
        if config_path is None:
            self.config = AutoConfig.from_pretrained(
                config.model, output_hidden_states=True
            )
        else:
            # NOTE: torch.load unpickles the file; only load configs from
            # trusted locations.
            self.config = torch.load(config_path)

        if pretrained:
            self.model = AutoModel.from_pretrained(config.model, config=self.config)
        else:
            # Auto classes cannot be instantiated directly; from_config builds
            # the architecture without loading pretrained weights.
            self.model = AutoModel.from_config(self.config)

        self.fc_dropout = nn.Dropout(config.fc_dropout)
        self.fc = nn.Linear(self.config.hidden_size, out_dim)
        self._init_weights(self.fc)
        self.autocast = autocast

    def _init_weights(self, module):
        """Re-initialize ``module`` in place according to its layer type.

        Linear and Embedding weights are drawn from a normal distribution with
        std ``self.config.initializer_range``; Linear biases and the Embedding
        padding row are zeroed. LayerNorm is reset to weight 1 and bias 0.
        Any other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, input_ids, attention_mask):
        """Run the backbone and return the first element of its output."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs[0]

    def forward(self, batch):
        """Apply ``self._forward`` to ``batch``, under CUDA autocast if enabled."""
        if not self.autocast:
            return self._forward(batch)
        with torch.autocast(device_type="cuda"):
            return self._forward(batch)

class DistributedDataClassifier(ABC):
"""Abstract class for running multi-node multi-GPU data classification"""

Expand Down Expand Up @@ -314,38 +267,6 @@ def load_config(self):
return AutoConfig.from_pretrained(self.path_or_name)


class QualityModel(HFModel):
    """``HFModel`` subclass that builds a ``CustomModel`` from a saved checkpoint.

    Args:
        config: Object exposing ``model``; it is passed to ``CustomModel`` and
            ``config.model`` is passed to the tokenizer loader and to the
            ``HFModel`` constructor.
        out_dim: Output dimension forwarded to ``CustomModel``.
        model_path: Path of the checkpoint read by ``load_model``.
        autocast: Forwarded to ``CustomModel``.
    """

    def __init__(self, config, out_dim=None, model_path=None, autocast=False):
        # Attributes are set before the base initializer runs, matching the
        # original ordering (presumably the base class may use them -- verify).
        self.config = config
        self.out_dim = out_dim
        self.model_path = model_path
        self.autocast = autocast
        super().__init__(self.config.model)

    def load_model(self, device="cuda"):
        """Build the classifier on ``device``, load its weights, return it in eval mode."""
        classifier = CustomModel(
            self.config,
            out_dim=self.out_dim,
            config_path=None,
            pretrained=True,
            autocast=self.autocast,
        ).to(device)

        checkpoint = torch.load(self.model_path, map_location="cpu")
        if "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]

        # Drop a leading "module." from parameter names so they match the
        # bare model's keys.
        prefix = "module."
        weights = {}
        for name in checkpoint.keys():
            target = name[len(prefix):] if name.startswith(prefix) else name
            weights[target] = checkpoint[name]

        classifier.load_state_dict(weights, strict=True)
        classifier.eval()
        return classifier

    def load_tokenizer(self):
        """Return the fast DeBERTa-v2 tokenizer for ``self.config.model``."""
        return DebertaV2TokenizerFast.from_pretrained(self.config.model)

    def load_config(self):
        """Return the ``AutoConfig`` loaded from ``self.path_or_name``."""
        return AutoConfig.from_pretrained(self.path_or_name)

class DomainClassifier(DistributedDataClassifier):
def __init__(
self,
Expand Down Expand Up @@ -412,7 +333,6 @@ def __init__(
max_chars=6000,
device_type="cuda",
autocast=True,
max_len=1024,
):
if len(labels) == 2:
out_dim = 1 # Binary classification
Expand All @@ -421,7 +341,6 @@ def __init__(
out_dim = len(labels) # Multiclass classification

self.prob_column = prob_column
self.max_len = max_len

model = QualityModel(
config=QualityModelConfig,
Expand Down
163 changes: 0 additions & 163 deletions nemo_curator/scripts/classifier_arg_utils.py

This file was deleted.

14 changes: 7 additions & 7 deletions nemo_curator/scripts/domain_classifier_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from nemo_curator.datasets import DocumentDataset

# Get relevant args
from nemo_curator.scripts.classifier_arg_utils import create_arg_parser
from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk
from nemo_curator.utils.file_utils import get_remaining_files
from nemo_curator.utils.script_utils import parse_distributed_classifier_args

warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -58,7 +58,7 @@ def main():
"Travel_and_Transportation",
]

args = create_arg_parser().parse_args()
args = parse_distributed_classifier_args().parse_args()
print(f"Arguments parsed = {args}", flush=True)
max_chars = 2000

Expand All @@ -67,11 +67,11 @@ def main():
global_st = time.time()
files_per_run = len(client.scheduler_info()["workers"]) * 2

if not os.path.exists(args.output_file_path):
os.makedirs(args.output_file_path)
if not os.path.exists(args.output_data_dir):
os.makedirs(args.output_data_dir)

input_files = get_remaining_files(
args.input_file_path, args.output_file_path, args.input_file_type
args.input_data_dir, args.output_data_dir, args.input_file_type
)
print(f"Total input files {len(input_files)}", flush=True)

Expand All @@ -81,7 +81,7 @@ def main():
add_filename = True

domain_classifier = DomainClassifier(
model_file_name=args.model_file_name,
model_path=args.model_path,
labels=labels,
max_chars=max_chars,
batch_size=args.batch_size,
Expand All @@ -106,7 +106,7 @@ def main():

write_to_disk(
df=df,
output_file_dir=args.output_file_path,
output_file_dir=args.output_data_dir,
write_to_filename=add_filename,
)
batch_et = time.time()
Expand Down
15 changes: 7 additions & 8 deletions nemo_curator/scripts/quality_classifier_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk
from nemo_curator.utils.file_utils import get_remaining_files

from .classifier_arg_utils import create_arg_parser
from nemo_curator.utils.script_utils import parse_distributed_classifier_args

warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -59,7 +58,7 @@ def get_labels(num_labels):


def main():
parser = create_arg_parser()
parser = parse_distributed_classifier_args()
parser = add_quality_model_specific_args(parser)
args = parser.parse_args()
labels = get_labels(args.num_labels)
Expand All @@ -71,11 +70,11 @@ def main():
global_st = time.time()
files_per_run = len(client.scheduler_info()["workers"]) * 2

if not os.path.exists(args.output_file_path):
os.makedirs(args.output_file_path)
if not os.path.exists(args.output_data_dir):
os.makedirs(args.output_data_dir)

input_files = get_remaining_files(
args.input_file_path, args.output_file_path, args.input_file_type
args.input_data_dir, args.output_data_dir, args.input_file_type
)
print(f"Total input files {len(input_files)}", flush=True)

Expand All @@ -85,7 +84,7 @@ def main():
add_filename = True

classifier = QualityClassifier(
model_file_name=args.model_file_name,
model_path=args.model_path,
max_chars=max_chars,
labels=labels,
batch_size=args.batch_size,
Expand All @@ -109,7 +108,7 @@ def main():
df = classifier(DocumentDataset(df)).df
write_to_disk(
df=df,
output_file_dir=args.output_file_path,
output_file_dir=args.output_data_dir,
write_to_filename=add_filename,
)
batch_et = time.time()
Expand Down
Loading

0 comments on commit ca713c8

Please sign in to comment.