Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/filtering #22

Merged
merged 38 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
da61ac6
Created filtering file and moved filtertoolong there
onadegibert Sep 20, 2023
cecd25a
Also moved FilterTooLongStats
onadegibert Sep 20, 2023
546135f
added word ratio filter
onadegibert Sep 20, 2023
0b86280
modified description
onadegibert Sep 20, 2023
47f0bce
modified word length ratio to take it from opusfilter
onadegibert Sep 22, 2023
a111c80
added word ratio unit as option
onadegibert Sep 22, 2023
a1ac834
added lengthfilter and wordlengthratio both with and without opusfilt…
onadegibert Sep 22, 2023
4e11078
added transform name to error for easier debugging
onadegibert Sep 26, 2023
429f50a
added implementations with opusfilter
onadegibert Sep 26, 2023
e1b7c64
removed implementations with opusfilter
onadegibert Sep 26, 2023
02ac61d
added repetition filter
onadegibert Sep 26, 2023
c2cb8ba
removed unnecessary prints
onadegibert Sep 26, 2023
301425c
added terminal punctuation filter
onadegibert Sep 26, 2023
0559401
added nonzeronumerals
onadegibert Sep 26, 2023
2775324
fixed flake8 typos
onadegibert Sep 26, 2023
3637759
fixed regex import
onadegibert Sep 26, 2023
4c31462
Created filtering file and moved filtertoolong there
onadegibert Sep 20, 2023
53c4326
Also moved FilterTooLongStats
onadegibert Sep 20, 2023
ced0ac9
added word ratio filter
onadegibert Sep 20, 2023
92cb7eb
modified description
onadegibert Sep 20, 2023
6f66bb4
modified word length ratio to take it from opusfilter
onadegibert Sep 22, 2023
3d2cd3a
added word ratio unit as option
onadegibert Sep 22, 2023
7f8c75e
added lengthfilter and wordlengthratio both with and without opusfilt…
onadegibert Sep 22, 2023
96cd2db
added transform name to error for easier debugging
onadegibert Sep 26, 2023
09b73d7
added implementations with opusfilter
onadegibert Sep 26, 2023
e61b85f
removed implementations with opusfilter
onadegibert Sep 26, 2023
4444819
added repetition filter
onadegibert Sep 26, 2023
7b79a77
removed unnecessary prints
onadegibert Sep 26, 2023
0f07953
added terminal punctuation filter
onadegibert Sep 26, 2023
b30380f
added nonzeronumerals
onadegibert Sep 26, 2023
da89ca2
fixed flake8 typos
onadegibert Sep 26, 2023
16b4fbb
fixed regex import
onadegibert Sep 26, 2023
1ccfdc8
modified imports
onadegibert Sep 26, 2023
0de386e
Merge branch 'feat/filtering' of https://github.com/Helsinki-NLP/mamm…
onadegibert Sep 26, 2023
46eebb2
deleted old files
onadegibert Sep 26, 2023
4a77d83
fixed flake8 typos
onadegibert Sep 26, 2023
178dff9
fixed typo
onadegibert Sep 26, 2023
d37ed98
changed data to tasks
onadegibert Sep 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mammoth/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def get_transforms_cls(transform_names):
transforms_cls = {}
for name in transform_names:
if name not in AVAILABLE_TRANSFORMS:
raise ValueError("specified tranform not supported!")
raise ValueError("Specified transform not supported!", name)
transforms_cls[name] = AVAILABLE_TRANSFORMS[name]
return transforms_cls

Expand Down
219 changes: 219 additions & 0 deletions mammoth/transforms/filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
from mammoth.transforms import register_transform
from .transform import Transform, ObservableStats
import re
import math
import itertools
import string
import difflib


class FilterTooLongStats(ObservableStats):
"""Runing statistics for FilterTooLongTransform."""

__slots__ = ["filtered"]

def __init__(self):
self.filtered = 1

def update(self, other: "FilterTooLongStats"):
self.filtered += other.filtered


@register_transform(name='filtertoolong')
class FilterTooLongTransform(Transform):
"""Filter out sentence that are too long."""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
"""Available options relating to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--src_seq_length", "-src_seq_length", type=int, default=200, help="Maximum source sequence length.")
group.add("--tgt_seq_length", "-tgt_seq_length", type=int, default=200, help="Maximum target sequence length.")

def _parse_opts(self):
self.src_seq_length = self.opts.src_seq_length
self.tgt_seq_length = self.opts.tgt_seq_length

def apply(self, example, is_train=False, stats=None, **kwargs):
"""Return None if too long else return as is."""
src_len = len(example['src'])
tgt_len = len(example['tgt'])
if src_len == 0 or tgt_len == 0:
# also filter empty strings
return None
if src_len > self.src_seq_length or tgt_len > self.tgt_seq_length:
if stats is not None:
stats.update(FilterTooLongStats())
return None
else:
return example

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}, {}={}'.format('src_seq_length', self.src_seq_length, 'tgt_seq_length', self.tgt_seq_length)


# Filters inspired by OpusFilter
# https://github.com/Helsinki-NLP/OpusFilter/blob/aca40bd064d9b087c5216de0568d7fb91a31d142/opusfilter/filters.py


@register_transform(name='filterwordratio')
class FilterWordRatio(Transform):
"""Filter out sentence based on word length ratio"""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
"""Available options relating to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--word_ratio_threshold", "-word_ratio_threshold", type=int, default=3,
help="Threshold for discarding sentences based on word ratio.")

def _parse_opts(self):
self.word_ratio_threshold = self.opts.word_ratio_threshold

def apply(self, example, **kwargs):
"""Return None if too long else return as is."""
src_len = len(example['src'])
tgt_len = len(example['tgt'])
lengths = sorted([src_len, tgt_len])
if lengths[0] == 0:
return None
else:
ratio = lengths[-1] / lengths[0]
if ratio < self.word_ratio_threshold:
return example
else:
return None

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}'.format('word_ratio_threshold', self.word_ratio_threshold)


@register_transform(name='filterrepetitions')
class FilterRepetitions(Transform):
"""Filter segments with repeated content. Useful e.g. for filtering data generated by a low-quality NMT model."""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
"""Available options relating to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--rep_threshold", "-rep_threshold", type=int, default=2,
help="Number of times the substring is repeated.")
group.add("--rep_min_len", "-rep_min_len", type=int, default=3,
help="Minimum length of the repeated pattern.")
group.add("--rep_max_len", "-rep_max_len", type=int, default=100,
help="Maximum length of the repeated pattern.")

def _parse_opts(self):
self.rep_threshold = self.opts.rep_threshold
self.rep_min_len = self.opts.rep_min_len
self.rep_max_len = self.opts.rep_max_len

def apply(self, example, **kwargs):
"""Return None if the repeated pattern appears more than n-threshold times."""
# compiled regexp for finding repetitions
rstring = f'(\\S.{{{self.rep_min_len-1},{self.rep_max_len}}}?)(?: *\\1){{{self.rep_threshold},}}'
regex = re.compile(rstring)
reps = []
for segment in example['src'], example['tgt']:
match = regex.search(' '.join(segment))
if match:
full = match.group(0)
repeated = match.group(1)
rep = full.count(repeated) - 1
else:
rep = 0
reps.append(rep)
if max(reps) > self.rep_threshold:
return None
else:
return example

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}, {}={}, {}={}'.format('rep_threshold', self.rep_threshold,
'rep_min_len', self.rep_min_len, 'rep_max_len', self.rep_max_len)


@register_transform(name='filterterminalpunct')
class FilterTerminalPunctuation(Transform):
"""Filter segments with respect to the co-occurrence of terminal punctuation marks"""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
"""Available options relating to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--punct_threshold", "-punct_threshold", type=int, default=-2,
help="Minimum penalty score for discarding sentences based on their terminal punctuation signs")

def _parse_opts(self):
self.punct_threshold = self.opts.punct_threshold

def apply(self, example, **kwargs):
"""Return None if the penalty is smaller than the threshold."""
src = ' '.join(example['src'])
tgt = ' '.join(example['tgt'])
spun = len([c for c in src if c in ['.', '?', '!', '…']])
tpun = len([c for c in tgt if c in ['.', '?', '!', '…']])
score = abs(spun - tpun)
if spun > 1:
score += spun - 1
if tpun > 1:
score += tpun - 1
score = -math.log(score + 1)
if score >= self.punct_threshold:
return example
else:
return None

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}'.format('punct_threshold', self.punct_threshold)


@register_transform(name='filternonzeronumerals')
class FilterNonZeroNumerals(Transform):
"""Filter segments based on a similarity measure of numerals between the segments with zeros removed"""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
"""Available options relating to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--nonzero_threshold", "-nonzero_threshold", type=float, default=0.5,
help="Threshold for discarding sentences based on numerals between the segments with zeros removed")

def _parse_opts(self):
self.nonzero_threshold = self.opts.nonzero_threshold

def apply(self, example, **kwargs):
"""Return None if the penalty is smaller than the threshold."""
src = ' '.join(example['src'])
tgt = ' '.join(example['tgt'])
nums = [[int(c) for c in sent if c in string.digits and c != '0'] for sent in [src, tgt]]
for num1, num2 in itertools.combinations(nums, 2):
seq = difflib.SequenceMatcher(None, num1, num2)
ratio = seq.ratio()
if ratio >= self.nonzero_threshold:
return example
else:
return None

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}'.format('nonzero_threshold', self.nonzero_threshold)
51 changes: 1 addition & 50 deletions mammoth/transforms/misc.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,6 @@
from mammoth.utils.logging import logger
from mammoth.transforms import register_transform
from .transform import Transform, ObservableStats


class FilterTooLongStats(ObservableStats):
"""Runing statistics for FilterTooLongTransform."""

__slots__ = ["filtered"]

def __init__(self):
self.filtered = 1

def update(self, other: "FilterTooLongStats"):
self.filtered += other.filtered


@register_transform(name='filtertoolong')
class FilterTooLongTransform(Transform):
"""Filter out sentence that are too long."""

def __init__(self, opts):
super().__init__(opts)

@classmethod
def add_options(cls, parser):
"""Avalilable options relate to this Transform."""
group = parser.add_argument_group("Transform/Filter")
group.add("--src_seq_length", "-src_seq_length", type=int, default=200, help="Maximum source sequence length.")
group.add("--tgt_seq_length", "-tgt_seq_length", type=int, default=200, help="Maximum target sequence length.")

def _parse_opts(self):
self.src_seq_length = self.opts.src_seq_length
self.tgt_seq_length = self.opts.tgt_seq_length

def apply(self, example, is_train=False, stats=None, **kwargs):
"""Return None if too long else return as is."""
src_len = len(example['src'])
tgt_len = len(example['tgt'])
if src_len == 0 or tgt_len == 0:
# also filter empty strings
return None
if src_len > self.src_seq_length or tgt_len > self.tgt_seq_length:
if stats is not None:
stats.update(FilterTooLongStats())
return None
else:
return example

def _repr_args(self):
"""Return str represent key arguments for class."""
return '{}={}, {}={}'.format('src_seq_length', self.src_seq_length, 'tgt_seq_length', self.tgt_seq_length)
from .transform import Transform


@register_transform(name='prefix')
Expand Down