diff --git a/onmt/transforms/filtering.py b/onmt/transforms/filtering.py index 46807af1..d0055a8d 100644 --- a/onmt/transforms/filtering.py +++ b/onmt/transforms/filtering.py @@ -83,4 +83,50 @@ def apply(self, example, **kwargs): def _repr_args(self): """Return str represent key arguments for class.""" - return '{}={}'.format('word_ratio_threshold', self.word_ratio_threshold) \ No newline at end of file + return '{}={}'.format('word_ratio_threshold', self.word_ratio_threshold) + +@register_transform(name='filterrepetitions') +class FilterRepetitions(Transform): + """Filter segments with repeated content. Useful e.g. for filtering data generated by a low-quality NMT model.""" + + def __init__(self, opts): + super().__init__(opts) + + @classmethod + def add_options(cls, parser): + """Avalilable options relate to this Transform.""" + group = parser.add_argument_group("Transform/Filter") + group.add("--rep_threshold", "-rep_threshold", type=int, default=2, help="Number of times the substring is repeated.") + group.add("--rep_min_len", "-rep_min_len", type=int, default=3, help="Minimum length of the repeated pattern.") + group.add("--rep_max_len", "-rep_max_len", type=int, default=100, help="Maximum length of the repeated pattern.") + + def _parse_opts(self): + self.rep_threshold = self.opts.rep_threshold + self.rep_min_len = self.opts.rep_min_len + self.rep_max_len = self.opts.rep_max_len + + def apply(self, example, **kwargs): + """Return None if the repeated pattern appears more than n-threshold times.""" + # compiled regexp for finding repetitions + rstring = f'(\\S.{{{self.rep_min_len-1},{self.rep_max_len}}}?)(?: *\\1){{{self.rep_threshold},}}' + regex = re.compile(rstring) + reps = [] + for segment in example['src'], example['tgt']: + match = regex.search(' '.join(segment)) + print(match) + if match: + full = match.group(0) + repeated = match.group(1) + rep = full.count(repeated) - 1 + else: + rep = 0 + reps.append(rep) + print(reps) + if max(reps) > self.rep_threshold: + return None + else: + return example + + def _repr_args(self): + """Return str represent key arguments for class.""" + return '{}={}, {}={}, {}={}'.format('rep_threshold', self.rep_threshold, 'rep_min_len', self.rep_min_len, 'rep_max_len', self.rep_max_len)