diff --git a/nemo_curator/distributed_data_classification/domain_classifier_inference.py b/nemo_curator/distributed_data_classification/domain_classifier_inference.py index edda23319..065675e16 100644 --- a/nemo_curator/distributed_data_classification/domain_classifier_inference.py +++ b/nemo_curator/distributed_data_classification/domain_classifier_inference.py @@ -49,7 +49,6 @@ def main(): "People_and_Society", "Pets_and_Animals", "Real_Estate", - "Reference", "Science", "Sensitive_Subjects", "Shopping", @@ -65,6 +64,10 @@ def main(): print("Starting domain classifier inference", flush=True) global_st = time.time() files_per_run = len(client.scheduler_info()["workers"]) * 2 + + if not os.path.exists(args.output_file_path): + os.makedirs(args.output_file_path) + input_files = get_remaining_files( args.input_file_path, args.output_file_path, args.input_file_type ) @@ -80,6 +83,7 @@ def main(): labels=labels, max_chars=max_chars, batch_size=args.batch_size, + out_dim=len(labels), autocast=args.autocast, ) diff --git a/nemo_curator/distributed_data_classification/quality_classifier_inference.py b/nemo_curator/distributed_data_classification/quality_classifier_inference.py index 3278d3939..a922f85a6 100644 --- a/nemo_curator/distributed_data_classification/quality_classifier_inference.py +++ b/nemo_curator/distributed_data_classification/quality_classifier_inference.py @@ -69,6 +69,10 @@ def main(): print("Starting quality classifier inference", flush=True) global_st = time.time() files_per_run = len(client.scheduler_info()["workers"]) * 2 + + if not os.path.exists(args.output_file_path): + os.makedirs(args.output_file_path) + input_files = get_remaining_files( args.input_file_path, args.output_file_path, args.input_file_type ) @@ -85,6 +89,7 @@ def main(): labels=labels, batch_size=args.batch_size, autocast=args.autocast, + out_dim=len(labels), ) for file_batch_id, i in enumerate(range(0, len(input_files), files_per_run)): @@ -122,3 +127,7 @@ def main(): def console_script(): main() + + +if __name__ == "__main__": + console_script()