Skip to content

Commit

Permalink
[REVIEW] Switch Models to use Crossfit (#58)
Browse files Browse the repository at this point in the history
Switch Models to use Crossfit
  • Loading branch information
VibhuJawa authored May 21, 2024
1 parent ecd4f4b commit 9f8578b
Show file tree
Hide file tree
Showing 17 changed files with 964 additions and 1,310 deletions.
4 changes: 2 additions & 2 deletions docs/user-guide/DistributedDataClassification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ Let's see how ``DomainClassifier`` works in a small excerpt taken from ``example
"Travel_and_Transportation",
]
model_file_name = "pytorch_model_file.pth"
model_path = "pytorch_model_file.pth"
files = get_all_files_paths_under("books_dataset/")
input_dataset = DocumentDataset.read_json(files, backend="cudf", add_filename=True)
domain_classifier = DomainClassifier(
model_file_name=model_file_name,
model_path=model_path,
labels=labels,
filter_by=["Games", "Sports"],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def main(args):
"Travel_and_Transportation",
]

model_file_name = "/path/to/pytorch_model_file.pth"
model_path = "/path/to/pytorch_model_file.pth"

# Input can be a string or list
input_file_path = "/path/to/data"
Expand All @@ -66,7 +66,7 @@ def main(args):
)

domain_classifier = DomainClassifier(
model_file_name=model_file_name,
model_path=model_path,
labels=labels,
filter_by=["Games", "Sports"],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@ def main(args):
global_st = time.time()

labels = ["High", "Medium", "Low"]
model_file_name = "/path/to/pytorch_model_file.pth"
model_path = "/path/to/pytorch_model_file.pth"

# Input can be a string or list
input_file_path = "/path/to/data"
output_file_path = "./"

client = get_client(args, cluster_type=args.device)

input_dataset = DocumentDataset.from_json(
input_dataset = DocumentDataset.read_json(
input_file_path, backend="cudf", add_filename=True
)

quality_classifier = QualityClassifier(
model_file_name=model_file_name,
model_path=model_path,
labels=labels,
filter_by=["High", "Medium"],
)
Expand Down
13 changes: 0 additions & 13 deletions nemo_curator/distributed_data_classification/__init__.py

This file was deleted.

163 changes: 0 additions & 163 deletions nemo_curator/distributed_data_classification/arg_utils.py

This file was deleted.

Loading

0 comments on commit 9f8578b

Please sign in to comment.