From fde7e6dc7d583451e6338e2a2b61aed9ef2fed69 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 16 Jan 2025 12:30:04 -0800 Subject: [PATCH] Minor CrossFit improvements Signed-off-by: Sarah Yurick --- nemo_curator/classifiers/base.py | 11 +++++++---- nemo_curator/classifiers/prompt_task_complexity.py | 8 ++++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/nemo_curator/classifiers/base.py b/nemo_curator/classifiers/base.py index 4f8cdc25..f00cff44 100644 --- a/nemo_curator/classifiers/base.py +++ b/nemo_curator/classifiers/base.py @@ -121,10 +121,13 @@ def _run_classifier_helper( prob_col: str = None, ) -> "dask_cudf.DataFrame": - if prob_col: - df[prob_col] = 0 - else: + if prob_col is None: prob_col = "_prob" + labeler = op.Labeler(labels, cols=[prob_col], suffix=label_col) + else: + labeler = op.Labeler( + labels, cols=[prob_col], keep_cols=[prob_col], suffix=label_col + ) columns_to_keep_list = df.columns.to_list() @@ -138,7 +141,7 @@ def _run_classifier_helper( batch_size=batch_size, pred_output_col=prob_col, ), - op.Labeler(labels, cols=[prob_col], suffix=label_col), + labeler, repartition=df.npartitions, keep_cols=columns_to_keep_list, ) diff --git a/nemo_curator/classifiers/prompt_task_complexity.py b/nemo_curator/classifiers/prompt_task_complexity.py index 4f2c4efc..ce11760c 100644 --- a/nemo_curator/classifiers/prompt_task_complexity.py +++ b/nemo_curator/classifiers/prompt_task_complexity.py @@ -336,11 +336,15 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset: df = dataset.df columns_to_keep_list = df.columns.to_list() - df["sliced_text"] = df[self.text_field].str.slice(0, self.max_chars) model = self.model classifier_pipe = op.Sequential( - op.Tokenizer(model, cols=["sliced_text"], tokenizer_type="default"), + op.Tokenizer( + model, + cols=[self.text_field], + tokenizer_type="default", + max_chars=self.max_chars, + ), op.Predictor( model, sorted_data_loader=True,