# Filters inspired by OpusFilter
# https://github.com/Helsinki-NLP/OpusFilter/blob/aca40bd064d9b087c5216de0568d7fb91a31d142/opusfilter/filters.py


@register_transform(name='filterwordratio')
class WordRatioFilter(Transform):
    """Filter out sentence pairs whose word-count ratio is too large.

    The ratio is computed as (longer side word count) / (shorter side word
    count), so it is always >= 1. A pair is kept only when the ratio is
    strictly below ``word_ratio_threshold``.
    """

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Available options related to this Transform."""
        group = parser.add_argument_group("Transform/Filter")
        # The ratio computed in `apply` is a float, so the threshold is
        # parsed as a float to allow fractional values such as 2.5.
        group.add("--word_ratio_threshold", "-word_ratio_threshold",
                  type=float, default=3,
                  help="Threshold for discarding sentences "
                       "based on word ratio.")

    def _parse_opts(self):
        """Copy the threshold from the parsed options onto the instance."""
        self.word_ratio_threshold = self.opts.word_ratio_threshold

    @staticmethod
    def _word_count(side):
        """Return the number of words in one side of an example.

        A raw string is split on whitespace; anything else is treated as an
        already-tokenized sequence and its length is used directly.
        """
        if isinstance(side, str):
            return len(side.split())
        return len(side)

    def apply(self, example, **kwargs):
        """Return ``example`` if its word ratio is acceptable, else None.

        Args:
            example: mapping with ``'src'`` and ``'tgt'`` entries, each a
                raw string or a sequence of tokens.
            **kwargs: accepted and ignored.

        Returns:
            ``example`` unchanged when longer/shorter word count is strictly
            below ``self.word_ratio_threshold``; ``None`` when the ratio is
            at or above the threshold, or when either side has no words.
        """
        src_len = self._word_count(example['src'])
        tgt_len = self._word_count(example['tgt'])
        shorter = min(src_len, tgt_len)
        longer = max(src_len, tgt_len)

        # An empty side makes the ratio undefined; discard the pair.
        if shorter == 0:
            return None
        if longer / shorter < self.word_ratio_threshold:
            return example
        return None

    def _repr_args(self):
        """Return str represent key arguments for class."""
        return '{}={}'.format('word_ratio_threshold',
                              self.word_ratio_threshold)