Skip to content

Commit

Permalink
added lengthfilter and wordlengthratio both with and without opusfilt…
Browse files Browse the repository at this point in the history
…er for speed testing
  • Loading branch information
onadegibert committed Sep 22, 2023
1 parent a111c80 commit a1ac834
Showing 1 changed file with 70 additions and 1 deletion.
71 changes: 70 additions & 1 deletion onmt/transforms/filtering.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from onmt.transforms import register_transform
from .transform import Transform, ObservableStats
from opusfilter.filters import LengthRatioFilter
from opusfilter.filters import LengthRatioFilter, LengthFilter

class FilterTooLongStats(ObservableStats):
"""Runing statistics for FilterTooLongTransform."""
Expand Down Expand Up @@ -49,11 +49,80 @@ def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}, {}={}'.format('src_seq_length', self.src_seq_length, 'tgt_seq_length', self.tgt_seq_length)


@register_transform(name='filtertoolong-opusfilter')
class FilterTooLongTransformOP(Transform):
    """Filter out examples that are too long, delegating to OpusFilter.

    The source and target sides are each checked with their own
    ``LengthFilter`` whose ``max_length`` is taken from ``src_seq_length``
    and ``tgt_seq_length`` respectively.
    """

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Available options related to this Transform."""
        # NOTE(review): these option names look identical to the ones used by
        # the non-OpusFilter length transform in this file -- confirm that
        # registering them twice does not conflict in the parser.
        group = parser.add_argument_group("Transform/Filter")
        group.add(
            "--src_seq_length", "-src_seq_length",
            type=int, default=200,
            help="Maximum source sequence length.",
        )
        group.add(
            "--tgt_seq_length", "-tgt_seq_length",
            type=int, default=200,
            help="Maximum target sequence length.",
        )

    def _parse_opts(self):
        """Copy the configured length limits from ``self.opts``."""
        self.src_seq_length = self.opts.src_seq_length
        self.tgt_seq_length = self.opts.tgt_seq_length

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Return None if either side is rejected, else return the example.

        Args:
            example: mapping with ``'src'`` and ``'tgt'`` entries.
                NOTE(review): presumably these must be in whatever form
                ``LengthFilter.get_length`` expects -- confirm against the
                caller.
            is_train: accepted for interface compatibility; not used here.
            stats: accepted for interface compatibility; not used here.

        Returns:
            ``example`` unchanged when both sides are accepted by their
            filter, otherwise ``None``.
        """
        # Each side gets its own filter so that the target is checked
        # against tgt_seq_length rather than the source limit.
        # NOTE(review): only max_length is passed, so LengthFilter's other
        # defaults (e.g. any minimum length) still apply -- confirm that is
        # intended.
        lf_src = LengthFilter(max_length=self.src_seq_length)
        lf_tgt = LengthFilter(max_length=self.tgt_seq_length)
        # Index 0: each filter is applied to a single segment.
        length_src = lf_src.get_length(example['src'], 0)
        length_tgt = lf_tgt.get_length(example['tgt'], 0)
        if lf_src.accept([length_src]) and lf_tgt.accept([length_tgt]):
            return example
        return None

    def _repr_args(self):
        """Return str represent key arguments for class."""
        return '{}={}, {}={}'.format(
            'src_seq_length', self.src_seq_length,
            'tgt_seq_length', self.tgt_seq_length,
        )


# Filters inspired by OpusFilter https://github.com/Helsinki-NLP/OpusFilter/blob/aca40bd064d9b087c5216de0568d7fb91a31d142/opusfilter/filters.py

@register_transform(name='filterwordratio')
class WordRatioFilter(Transform):
    """Drop examples whose source/target word counts are too unbalanced."""

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Register the command-line options used by this transform."""
        filter_group = parser.add_argument_group("Transform/Filter")
        filter_group.add(
            "--word_ratio_threshold", "-word_ratio_threshold",
            type=int, default=3,
            help="Threshold for discarding sentences based on word ratio.",
        )

    def _parse_opts(self):
        """Copy the configured ratio threshold from ``self.opts``."""
        self.word_ratio_threshold = self.opts.word_ratio_threshold

    def apply(self, example, **kwargs):
        """Return the example if its word-length ratio is acceptable.

        The ratio is the larger whitespace-split word count divided by the
        smaller one. The example is kept only when that ratio is strictly
        below ``word_ratio_threshold``; it is dropped (``None``) otherwise,
        and also whenever either side has no words.
        """
        n_src = len(example['src'].split())
        n_tgt = len(example['tgt'].split())
        shorter, longer = min(n_src, n_tgt), max(n_src, n_tgt)
        # An empty side would make the ratio undefined, so reject it early.
        if shorter == 0:
            return None
        return example if longer / shorter < self.word_ratio_threshold else None

    def _repr_args(self):
        """Return the key arguments of this transform as a string."""
        return 'word_ratio_threshold={}'.format(self.word_ratio_threshold)

@register_transform(name='filterwordratio-opusfilter')
class WordRatioFilterOP(Transform):
"""Filter out sentence based on word length ratio"""

def __init__(self, opts):
super().__init__(opts)
Expand Down

0 comments on commit a1ac834

Please sign in to comment.