Skip to content

Commit

Permalink
Minor CrossFit improvements
Browse files Browse the repository at this point in the history
Signed-off-by: Sarah Yurick <[email protected]>
  • Loading branch information
sarahyurick committed Jan 16, 2025
1 parent 7cfda44 commit fde7e6d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
11 changes: 7 additions & 4 deletions nemo_curator/classifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,13 @@ def _run_classifier_helper(
prob_col: str = None,
) -> "dask_cudf.DataFrame":

if prob_col:
df[prob_col] = 0
else:
if prob_col is None:
prob_col = "_prob"
labeler = op.Labeler(labels, cols=[prob_col], suffix=label_col)
else:
labeler = op.Labeler(
labels, cols=[prob_col], keep_cols=[prob_col], suffix=label_col
)

columns_to_keep_list = df.columns.to_list()

Expand All @@ -138,7 +141,7 @@ def _run_classifier_helper(
batch_size=batch_size,
pred_output_col=prob_col,
),
op.Labeler(labels, cols=[prob_col], suffix=label_col),
labeler,
repartition=df.npartitions,
keep_cols=columns_to_keep_list,
)
Expand Down
8 changes: 6 additions & 2 deletions nemo_curator/classifiers/prompt_task_complexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,15 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:

df = dataset.df
columns_to_keep_list = df.columns.to_list()
df["sliced_text"] = df[self.text_field].str.slice(0, self.max_chars)

model = self.model
classifier_pipe = op.Sequential(
op.Tokenizer(model, cols=["sliced_text"], tokenizer_type="default"),
op.Tokenizer(
model,
cols=[self.text_field],
tokenizer_type="default",
max_chars=self.max_chars,
),
op.Predictor(
model,
sorted_data_loader=True,
Expand Down

0 comments on commit fde7e6d

Please sign in to comment.