Skip to content

Commit

Permalink
added repetition filter
Browse files Browse the repository at this point in the history
  • Loading branch information
onadegibert committed Sep 26, 2023
1 parent e1b7c64 commit 02ac61d
Showing 1 changed file with 47 additions and 1 deletion.
48 changes: 47 additions & 1 deletion onmt/transforms/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,50 @@ def apply(self, example, **kwargs):

def _repr_args(self):
    """Return a string representation of the key arguments of this class.

    Returns:
        str: ``'word_ratio_threshold=<value>'`` built from
        ``self.word_ratio_threshold``.
    """
    # A single return: the duplicated second `return` was unreachable.
    return f'word_ratio_threshold={self.word_ratio_threshold}'

@register_transform(name='filterrepetitions')
class FilterRepetitions(Transform):
    """Filter segments with repeated content.

    Useful e.g. for filtering data generated by a low-quality NMT model.

    A segment is considered repetitive when some substring that starts
    with a non-whitespace character and is between ``rep_min_len`` and
    ``rep_max_len + 1`` characters long is immediately followed
    (optionally separated by spaces) by at least ``rep_threshold``
    further copies of itself.
    """

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Available options relate to this Transform."""
        group = parser.add_argument_group("Transform/Filter")
        group.add(
            "--rep_threshold", "-rep_threshold", type=int, default=2,
            help="Number of times the substring is repeated.")
        group.add(
            "--rep_min_len", "-rep_min_len", type=int, default=3,
            help="Minimum length of the repeated pattern.")
        group.add(
            "--rep_max_len", "-rep_max_len", type=int, default=100,
            help="Maximum length of the repeated pattern.")

    def _parse_opts(self):
        """Read the repetition options and pre-compile the detection regex."""
        self.rep_threshold = self.opts.rep_threshold
        self.rep_min_len = self.opts.rep_min_len
        self.rep_max_len = self.opts.rep_max_len
        # The pattern depends only on the options, so build it once here
        # instead of recompiling it for every example in ``apply``.
        self._rep_regex = self._build_regex()

    def _build_regex(self):
        """Compile the regex that matches a pattern plus its repetitions."""
        # Imported locally so this class does not rely on a module-level
        # import that is outside the visible part of the file.
        import re

        # Group 1 is the (lazily matched) repeated pattern; it must be
        # followed by at least ``rep_threshold`` copies of itself, each
        # optionally preceded by spaces.
        return re.compile(
            rf'(\S.{{{self.rep_min_len - 1},{self.rep_max_len}}}?)'
            rf'(?: *\1){{{self.rep_threshold},}}'
        )

    def apply(self, example, **kwargs):
        """Return None if a pattern is repeated at least rep_threshold times.

        Args:
            example: mapping with ``'src'`` and ``'tgt'`` entries, each an
                iterable of strings; each is joined with single spaces
                before matching.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            None when either side contains a pattern followed by at least
            ``rep_threshold`` repetitions of itself, otherwise ``example``
            unchanged.
        """
        for segment in (example['src'], example['tgt']):
            # The regex itself enforces the ``rep_threshold`` minimum, so
            # any match means the segment must be filtered out.
            if self._rep_regex.search(' '.join(segment)) is not None:
                return None
        return example

    def _repr_args(self):
        """Return str represent key arguments for class."""
        return (
            f'rep_threshold={self.rep_threshold}, '
            f'rep_min_len={self.rep_min_len}, '
            f'rep_max_len={self.rep_max_len}'
        )

0 comments on commit 02ac61d

Please sign in to comment.