Skip to content

Commit

Permalink
Merge pull request #22 from Helsinki-NLP/feat/filtering
Browse files Browse the repository at this point in the history
Feat/filtering
  • Loading branch information
onadegibert authored Sep 27, 2023
2 parents 22fea93 + d37ed98 commit 61910fc
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 51 deletions.
2 changes: 1 addition & 1 deletion mammoth/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def get_transforms_cls(transform_names):
transforms_cls = {}
for name in transform_names:
if name not in AVAILABLE_TRANSFORMS:
raise ValueError("specified tranform not supported!")
raise ValueError("Specified transform not supported!", name)
transforms_cls[name] = AVAILABLE_TRANSFORMS[name]
return transforms_cls

Expand Down
219 changes: 219 additions & 0 deletions mammoth/transforms/filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
from mammoth.transforms import register_transform
from .transform import Transform, ObservableStats
import re
import math
import itertools
import string
import difflib


class FilterTooLongStats(ObservableStats):
"""Runing statistics for FilterTooLongTransform."""

__slots__ = ["filtered"]

def __init__(self):
self.filtered = 1

def update(self, other: "FilterTooLongStats"):
self.filtered += other.filtered


@register_transform(name='filtertoolong')
class FilterTooLongTransform(Transform):
"""Filter out sentence that are too long."""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
"""Available options relating to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--src_seq_length", "-src_seq_length", type=int, default=200, help="Maximum source sequence length.")
group.add("--tgt_seq_length", "-tgt_seq_length", type=int, default=200, help="Maximum target sequence length.")

def _parse_opts(self):
self.src_seq_length = self.opts.src_seq_length
self.tgt_seq_length = self.opts.tgt_seq_length

def apply(self, example, is_train=False, stats=None, **kwargs):
"""Return None if too long else return as is."""
src_len = len(example['src'])
tgt_len = len(example['tgt'])
if src_len == 0 or tgt_len == 0:
# also filter empty strings
return None
if src_len > self.src_seq_length or tgt_len > self.tgt_seq_length:
if stats is not None:
stats.update(FilterTooLongStats())
return None
else:
return example

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}, {}={}'.format('src_seq_length', self.src_seq_length, 'tgt_seq_length', self.tgt_seq_length)


# Filters inspired by OpusFilter
# https://github.com/Helsinki-NLP/OpusFilter/blob/aca40bd064d9b087c5216de0568d7fb91a31d142/opusfilter/filters.py


@register_transform(name='filterwordratio')
class FilterWordRatio(Transform):
"""Filter out sentence based on word length ratio"""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
"""Available options relating to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--word_ratio_threshold", "-word_ratio_threshold", type=int, default=3,
help="Threshold for discarding sentences based on word ratio.")

def _parse_opts(self):
self.word_ratio_threshold = self.opts.word_ratio_threshold

def apply(self, example, **kwargs):
"""Return None if too long else return as is."""
src_len = len(example['src'])
tgt_len = len(example['tgt'])
lengths = sorted([src_len, tgt_len])
if lengths[0] == 0:
return None
else:
ratio = lengths[-1] / lengths[0]
if ratio < self.word_ratio_threshold:
return example
else:
return None

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}'.format('word_ratio_threshold', self.word_ratio_threshold)


@register_transform(name='filterrepetitions')
class FilterRepetitions(Transform):
"""Filter segments with repeated content. Useful e.g. for filtering data generated by a low-quality NMT model."""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
"""Available options relating to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--rep_threshold", "-rep_threshold", type=int, default=2,
help="Number of times the substring is repeated.")
group.add("--rep_min_len", "-rep_min_len", type=int, default=3,
help="Minimum length of the repeated pattern.")
group.add("--rep_max_len", "-rep_max_len", type=int, default=100,
help="Maximum length of the repeated pattern.")

def _parse_opts(self):
self.rep_threshold = self.opts.rep_threshold
self.rep_min_len = self.opts.rep_min_len
self.rep_max_len = self.opts.rep_max_len

def apply(self, example, **kwargs):
"""Return None if the repeated pattern appears more than n-threshold times."""
# compiled regexp for finding repetitions
rstring = f'(\\S.{{{self.rep_min_len-1},{self.rep_max_len}}}?)(?: *\\1){{{self.rep_threshold},}}'
regex = re.compile(rstring)
reps = []
for segment in example['src'], example['tgt']:
match = regex.search(' '.join(segment))
if match:
full = match.group(0)
repeated = match.group(1)
rep = full.count(repeated) - 1
else:
rep = 0
reps.append(rep)
if max(reps) > self.rep_threshold:
return None
else:
return example

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}, {}={}, {}={}'.format('rep_threshold', self.rep_threshold,
'rep_min_len', self.rep_min_len, 'rep_max_len', self.rep_max_len)


@register_transform(name='filterterminalpunct')
class FilterTerminalPunctuation(Transform):
"""Filter segments with respect to the co-occurrence of terminal punctuation marks"""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
"""Available options relating to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--punct_threshold", "-punct_threshold", type=int, default=-2,
help="Minimum penalty score for discarding sentences based on their terminal punctuation signs")

def _parse_opts(self):
self.punct_threshold = self.opts.punct_threshold

def apply(self, example, **kwargs):
"""Return None if the penalty is smaller than the threshold."""
src = ' '.join(example['src'])
tgt = ' '.join(example['tgt'])
spun = len([c for c in src if c in ['.', '?', '!', '…']])
tpun = len([c for c in tgt if c in ['.', '?', '!', '…']])
score = abs(spun - tpun)
if spun > 1:
score += spun - 1
if tpun > 1:
score += tpun - 1
score = -math.log(score + 1)
if score >= self.punct_threshold:
return example
else:
return None

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}'.format('punct_threshold', self.punct_threshold)


@register_transform(name='filternonzeronumerals')
class FilterNonZeroNumerals(Transform):
"""Filter segments based on a similarity measure of numerals between the segments with zeros removed"""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
"""Available options relating to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--nonzero_threshold", "-nonzero_threshold", type=float, default=0.5,
help="Threshold for discarding sentences based on numerals between the segments with zeros removed")

def _parse_opts(self):
self.nonzero_threshold = self.opts.nonzero_threshold

def apply(self, example, **kwargs):
"""Return None if the penalty is smaller than the threshold."""
src = ' '.join(example['src'])
tgt = ' '.join(example['tgt'])
nums = [[int(c) for c in sent if c in string.digits and c != '0'] for sent in [src, tgt]]
for num1, num2 in itertools.combinations(nums, 2):
seq = difflib.SequenceMatcher(None, num1, num2)
ratio = seq.ratio()
if ratio >= self.nonzero_threshold:
return example
else:
return None

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}'.format('nonzero_threshold', self.nonzero_threshold)
51 changes: 1 addition & 50 deletions mammoth/transforms/misc.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,6 @@
from mammoth.utils.logging import logger
from mammoth.transforms import register_transform
from .transform import Transform, ObservableStats


class FilterTooLongStats(ObservableStats):
"""Runing statistics for FilterTooLongTransform."""

__slots__ = ["filtered"]

def __init__(self):
self.filtered = 1

def update(self, other: "FilterTooLongStats"):
self.filtered += other.filtered


@register_transform(name='filtertoolong')
class FilterTooLongTransform(Transform):
"""Filter out sentence that are too long."""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
"""Avalilable options relate to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--src_seq_length", "-src_seq_length", type=int, default=200, help="Maximum source sequence length.")
group.add("--tgt_seq_length", "-tgt_seq_length", type=int, default=200, help="Maximum target sequence length.")

def _parse_opts(self):
self.src_seq_length = self.opts.src_seq_length
self.tgt_seq_length = self.opts.tgt_seq_length

def apply(self, example, is_train=False, stats=None, **kwargs):
"""Return None if too long else return as is."""
src_len = len(example['src'])
tgt_len = len(example['tgt'])
if src_len == 0 or tgt_len == 0:
# also filter empty strings
return None
if src_len > self.src_seq_length or tgt_len > self.tgt_seq_length:
if stats is not None:
stats.update(FilterTooLongStats())
return None
else:
return example

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}, {}={}'.format('src_seq_length', self.src_seq_length, 'tgt_seq_length', self.tgt_seq_length)
from .transform import Transform


@register_transform(name='prefix')
Expand Down

0 comments on commit 61910fc

Please sign in to comment.