diff --git a/mammoth/transforms/__init__.py b/mammoth/transforms/__init__.py
index b585e216..bedc86dd 100644
--- a/mammoth/transforms/__init__.py
+++ b/mammoth/transforms/__init__.py
@@ -14,7 +14,7 @@ def get_transforms_cls(transform_names):
     transforms_cls = {}
     for name in transform_names:
         if name not in AVAILABLE_TRANSFORMS:
-            raise ValueError("specified tranform not supported!")
+            raise ValueError("Specified transform not supported!", name)
         transforms_cls[name] = AVAILABLE_TRANSFORMS[name]
     return transforms_cls
 
diff --git a/mammoth/transforms/filtering.py b/mammoth/transforms/filtering.py
new file mode 100644
index 00000000..01831681
--- /dev/null
+++ b/mammoth/transforms/filtering.py
@@ -0,0 +1,208 @@
+from mammoth.transforms import register_transform
+from .transform import Transform, ObservableStats
+import re
+import math
+import itertools
+import string
+import difflib
+
+
+class FilterTooLongStats(ObservableStats):
+    """Running statistics for FilterTooLongTransform."""
+
+    __slots__ = ["filtered"]
+
+    def __init__(self):
+        self.filtered = 1
+
+    def update(self, other: "FilterTooLongStats"):
+        self.filtered += other.filtered
+
+
+@register_transform(name='filtertoolong')
+class FilterTooLongTransform(Transform):
+    """Filter out sentence that are too long."""
+
+    def __init__(self, opts):
+        super().__init__(opts)
+
+    @classmethod
+    def add_options(cls, parser):
+        """Available options relating to this Transform."""
+        group = parser.add_argument_group("Transform/Filter")
+        group.add("--src_seq_length", "-src_seq_length", type=int, default=200, help="Maximum source sequence length.")
+        group.add("--tgt_seq_length", "-tgt_seq_length", type=int, default=200, help="Maximum target sequence length.")
+
+    def _parse_opts(self):
+        self.src_seq_length = self.opts.src_seq_length
+        self.tgt_seq_length = self.opts.tgt_seq_length
+
+    def apply(self, example, is_train=False, stats=None, **kwargs):
+        """Return None if too long else return as is."""
+        src_len = len(example['src'])
+        tgt_len = len(example['tgt'])
+        if src_len == 0 or tgt_len == 0:
+            # also filter empty strings
+            return None
+        if src_len > self.src_seq_length or tgt_len > self.tgt_seq_length:
+            if stats is not None:
+                stats.update(FilterTooLongStats())
+            return None
+        else:
+            return example
+
+    def _repr_args(self):
+        """Return str represent key arguments for class."""
+        return '{}={}, {}={}'.format('src_seq_length', self.src_seq_length, 'tgt_seq_length', self.tgt_seq_length)
+
+
+# Filters inspired by OpusFilter
+# https://github.com/Helsinki-NLP/OpusFilter/blob/aca40bd064d9b087c5216de0568d7fb91a31d142/opusfilter/filters.py
+
+
+@register_transform(name='filterwordratio')
+class FilterWordRatio(Transform):
+    """Filter out sentence based on word length ratio"""
+
+    def __init__(self, opts):
+        super().__init__(opts)
+
+    @classmethod
+    def add_options(cls, parser):
+        """Available options relating to this Transform."""
+        group = parser.add_argument_group("Transform/Filter")
+        group.add("--word_ratio_threshold", "-word_ratio_threshold", type=float, default=3,
+                  help="Threshold for discarding sentences based on word ratio.")
+
+    def _parse_opts(self):
+        self.word_ratio_threshold = self.opts.word_ratio_threshold
+
+    def apply(self, example, **kwargs):
+        """Return None if either side is empty or the length ratio reaches the threshold."""
+        shorter, longer = sorted((len(example['src']), len(example['tgt'])))
+        if shorter == 0:
+            # The ratio is undefined for an empty side, so discard the example
+            return None
+        if longer / shorter < self.word_ratio_threshold:
+            return example
+        return None
+
+    def _repr_args(self):
+        """Return str represent key arguments for class."""
+        return '{}={}'.format('word_ratio_threshold', self.word_ratio_threshold)
+
+
+@register_transform(name='filterrepetitions')
+class FilterRepetitions(Transform):
+    """Filter segments with repeated content. Useful e.g. for filtering data generated by a low-quality NMT model."""
+
+    def __init__(self, opts):
+        super().__init__(opts)
+
+    @classmethod
+    def add_options(cls, parser):
+        """Available options relating to this Transform."""
+        group = parser.add_argument_group("Transform/Filter")
+        group.add("--rep_threshold", "-rep_threshold", type=int, default=2,
+                  help="Number of times the substring is repeated.")
+        group.add("--rep_min_len", "-rep_min_len", type=int, default=3,
+                  help="Minimum length of the repeated pattern.")
+        group.add("--rep_max_len", "-rep_max_len", type=int, default=100,
+                  help="Maximum length of the repeated pattern.")
+
+    def _parse_opts(self):
+        self.rep_threshold = self.opts.rep_threshold
+        self.rep_min_len = self.opts.rep_min_len
+        self.rep_max_len = self.opts.rep_max_len
+
+    def apply(self, example, **kwargs):
+        """Return None if the repeated pattern appears more than n-threshold times."""
+        # compiled regexp for finding repetitions
+        rstring = f'(\\S.{{{self.rep_min_len-1},{self.rep_max_len}}}?)(?: *\\1){{{self.rep_threshold},}}'
+        regex = re.compile(rstring)
+        reps = []
+        for segment in example['src'], example['tgt']:
+            match = regex.search(' '.join(segment))
+            if match:
+                full = match.group(0)
+                repeated = match.group(1)
+                rep = full.count(repeated) - 1
+            else:
+                rep = 0
+            reps.append(rep)
+        if max(reps) > self.rep_threshold:
+            return None
+        else:
+            return example
+
+    def _repr_args(self):
+        """Return str represent key arguments for class."""
+        return '{}={}, {}={}, {}={}'.format('rep_threshold', self.rep_threshold,
+                                            'rep_min_len', self.rep_min_len, 'rep_max_len', self.rep_max_len)
+
+
+@register_transform(name='filterterminalpunct')
+class FilterTerminalPunctuation(Transform):
+    """Filter segments with respect to the co-occurrence of terminal punctuation marks"""
+
+    def __init__(self, opts):
+        super().__init__(opts)
+
+    @classmethod
+    def add_options(cls, parser):
+        """Available options relating to this Transform."""
+        group = parser.add_argument_group("Transform/Filter")
+        group.add("--punct_threshold", "-punct_threshold", type=float, default=-2,
+                  help="Minimum penalty score for discarding sentences based on their terminal punctuation signs")
+
+    def _parse_opts(self):
+        self.punct_threshold = self.opts.punct_threshold
+
+    def apply(self, example, **kwargs):
+        """Return None if the penalty is smaller than the threshold."""
+        marks = ('.', '?', '!', '…')
+        spun = sum(1 for c in ' '.join(example['src']) if c in marks)
+        tpun = sum(1 for c in ' '.join(example['tgt']) if c in marks)
+        # Penalise a count mismatch plus every terminal mark beyond the first on either side
+        penalty = abs(spun - tpun) + max(spun - 1, 0) + max(tpun - 1, 0)
+        score = -math.log(penalty + 1)
+        if score >= self.punct_threshold:
+            return example
+        return None
+
+    def _repr_args(self):
+        """Return str represent key arguments for class."""
+        return '{}={}'.format('punct_threshold', self.punct_threshold)
+
+
+@register_transform(name='filternonzeronumerals')
+class FilterNonZeroNumerals(Transform):
+    """Filter segments based on a similarity measure of numerals between the segments with zeros removed"""
+
+    def __init__(self, opts):
+        super().__init__(opts)
+
+    @classmethod
+    def add_options(cls, parser):
+        """Available options relating to this Transform."""
+        group = parser.add_argument_group("Transform/Filter")
+        group.add("--nonzero_threshold", "-nonzero_threshold", type=float, default=0.5,
+                  help="Threshold for discarding sentences based on numerals between the segments with zeros removed")
+
+    def _parse_opts(self):
+        self.nonzero_threshold = self.opts.nonzero_threshold
+
+    def apply(self, example, **kwargs):
+        """Return None if the numeral similarity is smaller than the threshold."""
+        nums = [
+            [c for c in ' '.join(segment) if c in string.digits and c != '0']
+            for segment in (example['src'], example['tgt'])
+        ]
+        for num1, num2 in itertools.combinations(nums, 2):
+            if difflib.SequenceMatcher(None, num1, num2).ratio() < self.nonzero_threshold:
+                return None
+        return example
+
+    def _repr_args(self):
+        """Return str represent key arguments for class."""
+        return '{}={}'.format('nonzero_threshold', self.nonzero_threshold)
diff --git a/mammoth/transforms/misc.py b/mammoth/transforms/misc.py
index b8c1e8b1..84c7be8e 100644
--- a/mammoth/transforms/misc.py
+++ b/mammoth/transforms/misc.py
@@ -1,55 +1,6 @@
 from mammoth.utils.logging import logger
 from mammoth.transforms import register_transform
-from .transform import Transform, ObservableStats
-
-
-class FilterTooLongStats(ObservableStats):
-    """Runing statistics for FilterTooLongTransform."""
-
-    __slots__ = ["filtered"]
-
-    def __init__(self):
-        self.filtered = 1
-
-    def update(self, other: "FilterTooLongStats"):
-        self.filtered += other.filtered
-
-
-@register_transform(name='filtertoolong')
-class FilterTooLongTransform(Transform):
-    """Filter out sentence that are too long."""
-
-    def __init__(self, opts):
-        super().__init__(opts)
-
-    @classmethod
-    def add_options(cls, parser):
-        """Avalilable options relate to this Transform."""
-        group = parser.add_argument_group("Transform/Filter")
-        group.add("--src_seq_length", "-src_seq_length", type=int, default=200, help="Maximum source sequence length.")
-        group.add("--tgt_seq_length", "-tgt_seq_length", type=int, default=200, help="Maximum target sequence length.")
-
-    def _parse_opts(self):
-        self.src_seq_length = self.opts.src_seq_length
-        self.tgt_seq_length = self.opts.tgt_seq_length
-
-    def apply(self, example, is_train=False, stats=None, **kwargs):
-        """Return None if too long else return as is."""
-        src_len = len(example['src'])
-        tgt_len = len(example['tgt'])
-        if src_len == 0 or tgt_len == 0:
-            # also filter empty strings
-            return None
-        if src_len > self.src_seq_length or tgt_len > self.tgt_seq_length:
-            if stats is not None:
-                stats.update(FilterTooLongStats())
-            return None
-        else:
-            return example
-
-    def _repr_args(self):
-        """Return str represent key arguments for class."""
-        return '{}={}, {}={}'.format('src_seq_length', self.src_seq_length, 'tgt_seq_length', self.tgt_seq_length)
+from .transform import Transform
 
 
 @register_transform(name='prefix')