Skip to content

Commit

Permalink
fixed flake8 typos
Browse files Browse the repository at this point in the history
  • Loading branch information
onadegibert committed Sep 26, 2023
1 parent 0559401 commit 2775324
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
2 changes: 1 addition & 1 deletion onmt/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def get_transforms_cls(transform_names):
transforms_cls = {}
for name in transform_names:
if name not in AVAILABLE_TRANSFORMS:
raise ValueError("Specified transform not supported!" , name)
raise ValueError("Specified transform not supported!", name)
transforms_cls[name] = AVAILABLE_TRANSFORMS[name]
return transforms_cls

Expand Down
37 changes: 26 additions & 11 deletions onmt/transforms/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import string
import difflib


class FilterTooLongStats(ObservableStats):
"""Runing statistics for FilterTooLongTransform."""

Expand All @@ -17,6 +18,7 @@ def __init__(self):
def update(self, other: "FilterTooLongStats"):
self.filtered += other.filtered


@register_transform(name='filtertoolong')
class FilterTooLongTransform(Transform):
"""Filter out sentence that are too long."""
Expand Down Expand Up @@ -53,7 +55,10 @@ def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}, {}={}'.format('src_seq_length', self.src_seq_length, 'tgt_seq_length', self.tgt_seq_length)

# Filters inspired by OpusFilter https://github.com/Helsinki-NLP/OpusFilter/blob/aca40bd064d9b087c5216de0568d7fb91a31d142/opusfilter/filters.py

# Filters inspired by OpusFilter
# https://github.com/Helsinki-NLP/OpusFilter/blob/aca40bd064d9b087c5216de0568d7fb91a31d142/opusfilter/filters.py


@register_transform(name='filterwordratio')
class FilterWordRatio(Transform):
Expand All @@ -66,7 +71,8 @@ def __init__(self, opts):
def add_options(cls, parser):
"""Avalilable options relate to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--word_ratio_threshold", "-word_ratio_threshold", type=int, default=3, help="Threshold for discarding sentences based on word ratio.")
group.add("--word_ratio_threshold", "-word_ratio_threshold", type=int, default=3,
help="Threshold for discarding sentences based on word ratio.")

def _parse_opts(self):
self.word_ratio_threshold = self.opts.word_ratio_threshold
Expand All @@ -89,6 +95,7 @@ def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}'.format('word_ratio_threshold', self.word_ratio_threshold)


@register_transform(name='filterrepetitions')
class FilterRepetitions(Transform):
"""Filter segments with repeated content. Useful e.g. for filtering data generated by a low-quality NMT model."""
Expand All @@ -100,9 +107,12 @@ def __init__(self, opts):
def add_options(cls, parser):
"""Avalilable options relate to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--rep_threshold", "-rep_threshold", type=int, default=2, help="Number of times the substring is repeated.")
group.add("--rep_min_len", "-rep_min_len", type=int, default=3, help="Minimum length of the repeated pattern.")
group.add("--rep_max_len", "-rep_max_len", type=int, default=100, help="Maximum length of the repeated pattern.")
group.add("--rep_threshold", "-rep_threshold", type=int, default=2,
help="Number of times the substring is repeated.")
group.add("--rep_min_len", "-rep_min_len", type=int, default=3,
help="Minimum length of the repeated pattern.")
group.add("--rep_max_len", "-rep_max_len", type=int, default=100,
help="Maximum length of the repeated pattern.")

def _parse_opts(self):
self.rep_threshold = self.opts.rep_threshold
Expand Down Expand Up @@ -131,7 +141,9 @@ def apply(self, example, **kwargs):

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}, {}={}, {}={}'.format('rep_threshold', self.rep_threshold, 'rep_min_len', self.rep_min_len, 'rep_max_len', self.rep_max_len)
return '{}={}, {}={}, {}={}'.format('rep_threshold', self.rep_threshold,
'rep_min_len', self.rep_min_len, 'rep_max_len', self.rep_max_len)


@register_transform(name='filterterminalpunct')
class FilterTerminalPunctuation(Transform):
Expand All @@ -144,7 +156,8 @@ def __init__(self, opts):
def add_options(cls, parser):
"""Avalilable options relate to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--punct_threshold", "-punct_threshold", type=int, default=-2, help="Minimum penalty score for discarding sentences based on their terminal punctuation signs")
group.add("--punct_threshold", "-punct_threshold", type=int, default=-2,
help="Minimum penalty score for discarding sentences based on their terminal punctuation signs")

def _parse_opts(self):
self.punct_threshold = self.opts.punct_threshold
Expand All @@ -170,6 +183,7 @@ def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}'.format('punct_threshold', self.punct_threshold)


@register_transform(name='filternonzeronumerals')
class FilterNonZeroNumerals(Transform):
"""Filter segments based on a similarity measure of numerals between the segments with zeros removed"""
Expand All @@ -181,7 +195,8 @@ def __init__(self, opts):
def add_options(cls, parser):
"""Avalilable options relate to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--nonzero_threshold", "-nonzero_threshold", type=float, default=0.5, help="Threshold for discarding sentences based on a similarity measure of numerals between the segments with zeros removed")
group.add("--nonzero_threshold", "-nonzero_threshold", type=float, default=0.5,
help="Threshold for discarding sentences based on numerals between the segments with zeros removed")

def _parse_opts(self):
self.nonzero_threshold = self.opts.nonzero_threshold
Expand All @@ -190,15 +205,15 @@ def apply(self, example, **kwargs):
"""Return None if the penalty is smaller than the threshold."""
src = ' '.join(example['src'])
tgt = ' '.join(example['tgt'])
nums = [[int(c) for c in sent if c in string.digits and c != '0'] for sent in [src,tgt]]
nums = [[int(c) for c in sent if c in string.digits and c != '0'] for sent in [src, tgt]]
for num1, num2 in itertools.combinations(nums, 2):
seq = difflib.SequenceMatcher(None, num1, num2)
ratio = seq.ratio()
if ratio >= self.nonzero_threshold:
return example
else:
return None

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}'.format('nonzero_threshold', self.nonzero_threshold)
return '{}={}'.format('nonzero_threshold', self.nonzero_threshold)
1 change: 1 addition & 0 deletions onmt/transforms/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from onmt.transforms import register_transform
from .transform import Transform


@register_transform(name='prefix')
class PrefixTransform(Transform):
"""Add Prefix to src (& tgt) sentence."""
Expand Down

0 comments on commit 2775324

Please sign in to comment.