Add batched decorator (#18)
* Add batched decorator

Signed-off-by: Ryan Wolf <[email protected]>

* Fix stray legacy batched statements

Signed-off-by: Ryan Wolf <[email protected]>

---------

Signed-off-by: Ryan Wolf <[email protected]>
ryantwolf authored Apr 2, 2024
1 parent 4346c74 commit ccf107a
Showing 15 changed files with 94 additions and 91 deletions.
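The diffs below import a new `@batched` decorator from `nemo_curator.utils.decorators` and a companion check `is_batched` from `nemo_curator.utils.module_utils`; neither file's contents are loaded in this view. A minimal sketch of what the pair plausibly looks like, assuming the decorator simply tags a function with an attribute that the check reads:

# Hypothetical sketch only; the actual implementations in
# nemo_curator/utils/decorators.py and nemo_curator/utils/module_utils.py
# are not visible in this diff.
def batched(function):
    """Mark a function as taking a whole pandas Series of documents
    per call instead of a single document."""
    function.batched = True
    return function


def is_batched(function):
    """Return True if a function was decorated with @batched."""
    return getattr(function, "batched", False)

This replaces the old pattern of threading a batched=True flag through module constructors: the function itself now declares how it wants to be called.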
@@ -51,9 +51,9 @@ You could read, de-identify the dataset, and write it to an output directory using:
from nemo_curator.utils.distributed_utils import read_data, write_to_disk, get_client
from nemo_curator.utils.file_utils import get_batched_files
from nemo_curator.modules.modify import Modify
-from nemo_curator.modifiers.pii_modifier import PiiModifierBatched
+from nemo_curator.modifiers.pii_modifier import PiiModifier
-modifier = PiiModifierBatched(
+modifier = PiiModifier(
language="en",
supported_entities=["PERSON", "EMAIL_ADDRESS"],
anonymize_action="replace",
@@ -70,7 +70,7 @@ You could read, de-identify the dataset, and write it to an output directory using:
dataset = DocumentDataset(source_data)
print(f"Dataset has {source_data.npartitions} partitions")
-modify = Modify(modifier, batched=True)
+modify = Modify(modifier)
modified_dataset = modify(dataset)
write_to_disk(modified_dataset.df,
"output_directory",
@@ -80,11 +80,11 @@ You could read, de-identify the dataset, and write it to an output directory using:
Let's walk through this code line by line.

-* ``modifier = PiiModifierBatched`` creates an instance of ``PiiModifierBatched`` class that is responsible for PII de-identification
+* ``modifier = PiiModifier`` creates an instance of ``PiiModifier`` class that is responsible for PII de-identification
* ``for file_names in get_batched_files`` retrieves a batch of 32 documents from the `book_dataset`
* ``source_data = read_data(file_names, file_type="jsonl", backend='pandas', add_filename=True)`` reads the data from all the files using Dask using Pandas as the backend. The ``add_filename`` argument ensures that the output files have the same filename as the input files.
* ``dataset = DocumentDataset(source_data)`` creates an instance of ``DocumentDataset`` using the batch files. ``DocumentDataset`` is the standard format for text datasets in NeMo Curator.
-* ``modify = Modify(modifier, batched=True)`` creates an instance of the ``Modify`` class. This class can take any modifier as an argument
+* ``modify = Modify(modifier)`` creates an instance of the ``Modify`` class. This class can take any modifier as an argument
* ``modified_dataset = modify(dataset)`` modifies the data in the dataset by performing the PII de-identification based upon the passed parameters.
* ``write_to_disk(modified_dataset.df ....`` writes the de-identified documents to disk.

5 changes: 2 additions & 3 deletions examples/classifier_filtering.py
@@ -19,7 +19,7 @@

import nemo_curator as nc
from nemo_curator.datasets import DocumentDataset
-from nemo_curator.filters import BatchedFastTextQualityFilter
+from nemo_curator.filters import FastTextQualityFilter
from nemo_curator.modifiers import FastTextLabelModifier
from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk
from nemo_curator.utils.file_utils import get_all_files_paths_under
@@ -85,9 +85,8 @@ def main(args):
# Filter data
target_dataset = load_dataset(low_quality_data_path)
filter_pipeline = nc.ScoreFilter(
-BatchedFastTextQualityFilter(model_path),
+FastTextQualityFilter(model_path),
score_field="quality_score",
-batched=True,
score_type=float,
)
filtered_dataset = filter_pipeline(target_dataset)
6 changes: 3 additions & 3 deletions examples/find_pii_and_deidentify.py
@@ -18,7 +18,7 @@
import pandas as pd

from nemo_curator.datasets import DocumentDataset
-from nemo_curator.modifiers.pii_modifier import PiiModifierBatched
+from nemo_curator.modifiers.pii_modifier import PiiModifier
from nemo_curator.modules.modify import Modify
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import add_distributed_args
@@ -35,15 +35,15 @@ def console_script():
dd = dask.dataframe.from_pandas(dataframe, npartitions=1)
dataset = DocumentDataset(dd)

-modifier = PiiModifierBatched(
+modifier = PiiModifier(
log_dir="./logs",
batch_size=2000,
language="en",
supported_entities=["PERSON", "EMAIL_ADDRESS"],
anonymize_action="replace",
)

-modify = Modify(modifier, batched=True)
+modify = Modify(modifier)
modified_dataset = modify(dataset)
modified_dataset.df.to_json("output_files/*.jsonl", lines=True, orient="records")

7 changes: 1 addition & 6 deletions nemo_curator/filters/__init__.py
@@ -12,11 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from .classifier_filter import (
-    BatchedFastTextQualityFilter,
-    FastTextLangId,
-    FastTextQualityFilter,
-)
+from .classifier_filter import FastTextLangId, FastTextQualityFilter
from .code import (
AlphaFilter,
GeneralCommentToCodeFilter,
@@ -54,7 +50,6 @@
)

__all__ = [
"BatchedFastTextQualityFilter",
"DocumentFilter",
"import_filter",
"FastTextLangId",
59 changes: 15 additions & 44 deletions nemo_curator/filters/classifier_filter.py
@@ -17,6 +17,7 @@
import pandas as pd

from nemo_curator.filters.doc_filter import DocumentFilter
+from nemo_curator.utils.decorators import batched
from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker


@@ -34,42 +35,7 @@ def __init__(self, model_path=None, label="__label__hq", alpha=3, seed=42):
self._seed = np.random.seed(seed)
self._name = "fasttext_quality_filter"

-def score_document(self, text):
-text = text.replace("\n", " ").replace("__label__", " ")
-model_attr = f"{self._name}_{self._model_path}"
-# Workers don't exist during type inference
-try:
-model = load_object_on_worker(model_attr, self._load_model, {})
-except NoWorkerError:
-return 1.0
-pred = model.predict(text)
-document_score = pred[1][0]
-if pred[0][0] != self._label:
-document_score = 1 - document_score

-return document_score

-def keep_document(self, score):
-return np.random.pareto(self._alpha) > 1 - score

-def _load_model(self):
-return fasttext.load_model(self._model_path)


-class BatchedFastTextQualityFilter(DocumentFilter):

-def __init__(self, model_path=None, label="__label__hq", alpha=3, seed=42):
-if model_path is None:
-raise ValueError(
-"Must provide a valid path to a FastText model "
-"to compute document scores with this filter"
-)
-self._model_path = model_path
-self._label = label
-self._alpha = alpha
-self._seed = np.random.seed(seed)
-self._name = "fasttext_quality_filter"

+@batched
def score_document(self, df):
model_attr = f"{self._name}_{self._model_path}"
try:
@@ -88,6 +54,7 @@ def _score_document(text):

return df.apply(_score_document)

+@batched
def keep_document(self, df):
return np.random.pareto(self._alpha, size=len(df)) > 1 - df

@@ -108,19 +75,23 @@ def __init__(self, model_path=None, min_langid_score=0.3):
self._cutoff = min_langid_score
self._name = "lang_id"

-def score_document(self, text):
-pp = text.strip().replace("\n", " ")

+@batched
+def score_document(self, df):
model_attr = f"{self._name}_{self._model_path}"
try:
model = load_object_on_worker(model_attr, self._load_model, {})
except NoWorkerError:
-return [1.0, "N/A"]
-label, score = model.predict(pp, k=1)
-score = score[0]
-lang_code = label[0][-2:].upper()
+return pd.Series([[1.0, "N/A"] for _ in range(len(df))])

-return [score, lang_code]
+def _score_document(text):
+pp = text.strip().replace("\n", " ")
+label, score = model.predict(pp, k=1)
+score = score[0]
+lang_code = label[0][-2:].upper()

+return [score, lang_code]

+return df.apply(_score_document)

def keep_document(self, score):
return score[0] >= self._cutoff
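With the decorator applied, score_document now receives a pandas Series of texts and must return one result per row, while FastTextLangId.keep_document stays per-document. A self-contained sketch of that contract, using a stub in place of the real fastText model (fake_predict and score_batch are illustrative names, not NeMo Curator APIs):

import pandas as pd

def fake_predict(text, k=1):
    # Stub standing in for a fastText model's predict(); illustrative only.
    return (("__label__en",), (0.9,))

def score_batch(texts: pd.Series) -> pd.Series:
    # One [score, lang_code] pair per document, mirroring how the batched
    # score_document above builds its result with df.apply.
    def _score(text):
        label, score = fake_predict(text.strip().replace("\n", " "), k=1)
        return [score[0], label[0][-2:].upper()]
    return texts.apply(_score)

def keep_document(score, cutoff=0.3):
    # Per-document predicate, as in FastTextLangId.keep_document.
    return score[0] >= cutoff

scores = score_batch(pd.Series(["hello world", "bonjour le monde"]))
mask = scores.apply(keep_document)  # boolean Series, one entry per document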
10 changes: 6 additions & 4 deletions nemo_curator/modifiers/pii_modifier.py
@@ -18,14 +18,15 @@

from nemo_curator.modifiers import DocumentModifier
from nemo_curator.pii.algorithm import DEFAULT_LANGUAGE
+from nemo_curator.utils.decorators import batched
from nemo_curator.utils.distributed_utils import load_object_on_worker

-__all__ = ["PiiModifierBatched"]
+__all__ = ["PiiModifier"]

DEFAULT_BATCH_SIZE = 2000


-class PiiModifierBatched(DocumentModifier):
+class PiiModifier(DocumentModifier):
"""
This class is the entry point to using the PII de-identification module on documents stored as CSV, JSONL or
other formats. It works with the `Modify` functionality as shown below:
@@ -34,13 +35,13 @@ class PiiModifierBatched(DocumentModifier):
dd = dask.dataframe.from_pandas(dataframe, npartitions=1)
dataset = DocumentDataset(dd)
-modifier = PiiModifierBatched(
+modifier = PiiModifier(
batch_size=2000,
language='en',
supported_entities=['PERSON', "EMAIL_ADDRESS"],
anonymize_action='replace')
-modify = Modify(modifier, batched=True)
+modify = Modify(modifier)
modified_dataset = modify(dataset)
modified_dataset.df.to_json('output_files/*.jsonl', lines=True, orient='records')
@@ -65,6 +66,7 @@ def __init__(
self.batch_size = batch_size
self.device = device

+@batched
def modify_document(self, text: pd.Series, partition_info: Dict = None):
import logging

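The modify_document signature above accepts a partition_info keyword. Dask's map_partitions passes that argument automatically when the mapped function declares it: a dict with 'number' and 'division' keys at compute time, and None during Dask's meta-inference pass. A small illustration, independent of NeMo Curator (tag_partition is a hypothetical name):

import dask.dataframe as dd
import pandas as pd

def tag_partition(series: pd.Series, partition_info=None):
    # partition_info is None while Dask infers output types, and a dict
    # like {'number': 0, 'division': 0} when partitions actually run.
    number = partition_info["number"] if partition_info else -1
    return series + f" [partition {number}]"

ddf = dd.from_pandas(pd.DataFrame({"text": ["a", "b", "c"]}), npartitions=2)
print(ddf["text"].map_partitions(tag_partition, meta=(None, str)).compute())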
19 changes: 7 additions & 12 deletions nemo_curator/modules/filter.py
@@ -15,12 +15,11 @@
from dask.typing import no_default

from nemo_curator.datasets import DocumentDataset
+from nemo_curator.utils.module_utils import is_batched


class Score:
-def __init__(
-self, score_fn, score_field, text_field="text", batched=False, score_type=None
-):
+def __init__(self, score_fn, score_field, text_field="text", score_type=None):
"""
Args:
score_fn: The score function that takes in a document string and outputs a score for the document
@@ -30,7 +29,6 @@ def __init__(
self.score_fn = score_fn
self.score_field = score_field
self.text_field = text_field
-self.batched = batched
self.score_type = score_type

def __call__(self, dataset):
@@ -40,7 +38,7 @@ def __call__(self, dataset):
else:
meta = no_default

-if self.batched:
+if is_batched(self.score_fn):
dataset.df[self.score_field] = dataset.df[self.text_field].map_partitions(
self.score_fn, meta=meta
)
@@ -53,7 +51,7 @@


class Filter:
-def __init__(self, filter_fn, filter_field, invert=False, batched=False):
+def __init__(self, filter_fn, filter_field, invert=False):
"""
Args:
filter_fn: A function that returns True if the document is to be kept
@@ -63,10 +61,9 @@ def __init__(self, filter_fn, filter_field, invert=False, batched=False):
self.filter_fn = filter_fn
self.filter_field = filter_field
self.invert = invert
-self.batched = batched

def __call__(self, dataset):
-if self.batched:
+if is_batched(self.filter_fn):
bool_mask = dataset.df[self.filter_field].map_partitions(
self.filter_fn, meta=(None, bool)
)
@@ -89,7 +86,6 @@ def __init__(
score_field=None,
score_type=None,
invert=False,
-batched=False,
):
"""
Args:
@@ -100,7 +96,6 @@
self.score_field = score_field
self.score_type = score_type
self.invert = invert
-self.batched = batched

def __call__(self, dataset):
# Set the metadata for the function calls if provided
Expand All @@ -109,7 +104,7 @@ def __call__(self, dataset):
else:
meta = no_default

-if self.batched:
+if is_batched(self.filter_obj.score_document):
scores = dataset.df[self.text_field].map_partitions(
self.filter_obj.score_document, meta=meta
)
@@ -121,7 +116,7 @@
if self.score_field is not None:
dataset.df[self.score_field] = scores

-if self.batched:
+if is_batched(self.filter_obj.keep_document):
bool_mask = scores.map_partitions(
self.filter_obj.keep_document, meta=(None, bool)
)
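Note that ScoreFilter now inspects each method separately, so a filter can mix a batched score_document with a per-document keep_document, as FastTextLangId does above. A minimal sketch of the dispatch, assuming is_batched just reads the attribute set by the decorator (apply_fn is a hypothetical helper, not part of the library):

import pandas as pd

def apply_fn(values: pd.Series, fn):
    # Mirrors the branching in Score/Filter/ScoreFilter.__call__: batched
    # functions get the whole Series; unbatched ones are applied element-wise.
    if getattr(fn, "batched", False):  # what is_batched checks (assumed)
        return fn(values)
    return values.apply(fn)

In the Dask versions above the same branching happens via map_partitions, so each partition's pandas Series takes the place of values here.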
6 changes: 3 additions & 3 deletions nemo_curator/modules/modify.py
@@ -14,16 +14,16 @@

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modifiers import DocumentModifier
+from nemo_curator.utils.module_utils import is_batched


class Modify:
def __init__(self, modifier: DocumentModifier, text_field="text", batched=False):
def __init__(self, modifier: DocumentModifier, text_field="text"):
self.modifier = modifier
self.text_field = text_field
-self.batched = batched

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
-if self.batched:
+if is_batched(self.modifier.modify_document):
dataset.df[self.text_field] = dataset.df[self.text_field].map_partitions(
self.modifier.modify_document, meta=(None, str)
)
2 changes: 1 addition & 1 deletion nemo_curator/modules/task.py
@@ -51,7 +51,7 @@ def __init__(
tasks = [tasks]
self.tasks = tasks
self.text_field = text_field
-self.max_ngram_size = 13
+self.max_ngram_size = max_ngram_size
self.max_matches = max_matches
self.min_document_length = min_document_length
self.remove_char_each_side = remove_char_each_side
Empty file.
6 changes: 3 additions & 3 deletions nemo_curator/scripts/find_pii_and_deidentify.py
@@ -18,7 +18,7 @@
from pathlib import Path

from nemo_curator.datasets import DocumentDataset
-from nemo_curator.modifiers.pii_modifier import PiiModifierBatched
+from nemo_curator.modifiers.pii_modifier import PiiModifier
from nemo_curator.modules.modify import Modify

# from nemo_curator.pii.algorithm import DEFAULT_LANGUAGE
@@ -43,7 +43,7 @@ def main(args):
args.supported_entities.split(",") if args.supported_entities else None
)

-modifier = PiiModifierBatched(
+modifier = PiiModifier(
language=args.language,
supported_entities=supported_entities,
anonymize_action=args.anonymize_action,
Expand All @@ -68,7 +68,7 @@ def main(args):
dataset = DocumentDataset(source_data)
logging.debug(f"Dataset has {source_data.npartitions} partitions")

-modify = Modify(modifier, batched=True)
+modify = Modify(modifier)
modified_dataset = modify(dataset)
write_to_disk(
modified_dataset.df,