Skip to content

Commit

Permalink
merging mass and bart
Browse files Browse the repository at this point in the history
  • Loading branch information
Mickus Timothee committed Sep 21, 2023
1 parent 21d21e1 commit 9f3aab8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 118 deletions.
6 changes: 3 additions & 3 deletions onmt/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@ def test_transform_register(self):
"sentencepiece",
"bpe",
"onmt_tokenize",
"bart",
"ae_noise",
"switchout",
"tokendrop",
"tokenmask",
]
get_transforms_cls(builtin_transform)

def test_vocab_required_transform(self):
transforms_cls = get_transforms_cls(["bart", "switchout"])
transforms_cls = get_transforms_cls(["ae_noise", "switchout"])
opt = Namespace(seed=-1, switchout_temperature=1.0)
# transforms that require a vocab are not created when no vocab is provided
transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None)
self.assertEqual(len(transforms), 0)
with self.assertRaises(ValueError):
transforms_cls["switchout"](opt).warm_up(vocabs=None)
transforms_cls["bart"](opt).warm_up(vocabs=None)
transforms_cls["ae_noise"](opt).warm_up(vocabs=None)

def test_transform_specials(self):
transforms_cls = get_transforms_cls(["prefix"])
Expand Down
139 changes: 24 additions & 115 deletions onmt/transforms/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,11 @@ def __repr__(self):
return '{}({})'.format(cls_name, cls_args)


@register_transform(name='bart')
class BARTNoiseTransform(Transform):
@register_transform(name='ae_noise')
class NoiseTransform(Transform):
def __init__(self, opts):
    """Initialise the transform from parsed options.

    Delegates to the base-class initialiser with ``opts`` and then records
    which denoising objective was selected (``opts.denoising_objective``),
    which ``apply`` later uses to choose between the BART and MASS paths.
    """
    super().__init__(opts)
    # Cached here so that dispatch in ``apply`` is a plain attribute read.
    self.denoising_objective = opts.denoising_objective

@classmethod
def get_specials(cls, opts):
Expand All @@ -338,8 +339,8 @@ def _set_seed(self, seed):

@classmethod
def add_options(cls, parser):
"""Avalilable options relate to BART."""
group = parser.add_argument_group("Transform/BART")
"""Avalilable options relate to BART/MASS denoising."""
group = parser.add_argument_group("Transform/Denoising AE")
group.add(
"--permute_sent_ratio",
"-permute_sent_ratio",
Expand Down Expand Up @@ -394,6 +395,13 @@ def add_options(cls, parser):
choices=[-1, 0, 1],
help="When masking N tokens, replace with 0, 1, or N tokens. (use -1 for N)",
)
group.add(
"--denoising_objective",
type=str,
default='bart',
choices=['bart', 'mass'],
help='choose between BART-style or MASS-style denoising objectives'
)

@classmethod
def require_vocab(cls):
Expand Down Expand Up @@ -424,124 +432,17 @@ def warm_up(self, vocabs):
is_joiner=is_joiner,
)

def apply(self, example, is_train=False, stats=None, **kwargs):
def apply_bart(self, example, is_train=False, stats=None, **kwargs):
    """Apply BART noise to the src side tokens of ``example``.

    When ``is_train`` is falsy the example is returned untouched. Otherwise
    ``example['src']`` is replaced, in place, by the output of
    ``self.bart_noise.apply`` on the current source tokens.

    ``stats`` and ``**kwargs`` are accepted for signature compatibility but
    are not used here.

    Returns:
        The same ``example`` object that was passed in.
    """
    # Noise is a training-time augmentation only; pass everything else through.
    if not is_train:
        return example
    example['src'] = self.bart_noise.apply(example['src'])
    return example

def _repr_args(self):
"""Return str represent key arguments for BART."""
return repr(self.bart_noise)


@register_transform(name='mass')
class MASSNoiseTransform(Transform):
def __init__(self, opts):
super().__init__(opts)

@classmethod
def get_specials(cls, opts):
# FIXME: If a different mask token is used, then it is up to you to add it to specials
return ({DefaultTokens.MASK}, set())

def _set_seed(self, seed):
"""set seed to ensure reproducibility."""
BARTNoising.set_random_seed(seed)

@classmethod
def add_options(cls, parser):
"""Avalilable options relate to BART."""
group = parser.add_argument_group("Transform/MASS")
group.add(
"--permute_sent_ratio",
"-permute_sent_ratio",
type=float,
default=0.0,
help="Permute this proportion of sentences "
"(boundaries defined by {}) in all inputs.".format(DefaultTokens.SENT_FULL_STOPS),
)
group.add("--rotate_ratio", "-rotate_ratio", type=float, default=0.0, help="Rotate this proportion of inputs.")
group.add(
"--insert_ratio",
"-insert_ratio",
type=float,
default=0.0,
help="Insert this percentage of additional random tokens.",
)
group.add(
"--random_ratio",
"-random_ratio",
type=float,
default=0.0,
help="Instead of using {}, use random token this often.".format(DefaultTokens.MASK),
)

group.add(
"--mask_ratio",
"-mask_ratio",
type=float,
default=0.0,
help="Fraction of words/subwords that will be masked.",
)
group.add(
"--mask_length",
"-mask_length",
type=str,
default="subword",
choices=["subword", "word", "span-poisson"],
help="Length of masking window to apply.",
)
group.add(
"--poisson_lambda",
"-poisson_lambda",
type=float,
default=3.0,
help="Lambda for Poisson distribution to sample span length if `-mask_length` set to span-poisson.",
)
group.add(
"--replace_length",
"-replace_length",
type=int,
default=-1,
choices=[-1, 0, 1],
help="When masking N tokens, replace with 0, 1, or N tokens. (use -1 for N)",
)

@classmethod
def require_vocab(cls):
"""Override this method to inform it need vocab to start."""
return True

def warm_up(self, vocabs):
super().warm_up(vocabs)

subword_type = self.opts.src_subword_type
if self.opts.mask_length == 'subword':
if subword_type == 'none':
raise ValueError(
f'src_subword_type={subword_type} incompatible with ' f'mask_length={self.opts.mask_length}!'
)
is_joiner = (subword_type == 'bpe') if subword_type != 'none' else None
self.mass_noise = BARTNoising(
self.vocabs['src'].itos,
mask_tok=DefaultTokens.MASK,
mask_ratio=self.opts.mask_ratio,
insert_ratio=self.opts.insert_ratio,
permute_sent_ratio=self.opts.permute_sent_ratio,
poisson_lambda=self.opts.poisson_lambda,
replace_length=self.opts.replace_length,
rotate_ratio=self.opts.rotate_ratio,
mask_length=self.opts.mask_length,
random_ratio=self.opts.random_ratio,
is_joiner=is_joiner,
)

def apply(self, example, is_train=False, stats=None, **kwargs):
def apply_mass(self, example, is_train=False, stats=None, **kwargs):
"""Apply BART noise to src side tokens, then complete as MASS scheme."""
if is_train:
masked = self.mass_noise.apply(example['src'])
masked = self.bart_noise.apply(example['src'])
complement_masked = [
DefaultTokens.MASK if masked_item != DefaultTokens.MASK
else source_item
Expand All @@ -559,6 +460,14 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
}
return example

def apply(self, example, is_train=False, stats=None, **kwargs):
    """Apply the configured denoising objective to ``example``.

    Dispatches on ``self.denoising_objective`` to ``apply_bart`` or
    ``apply_mass``. The caller's ``is_train`` and ``stats`` are forwarded
    as given; hard-coding ``is_train=False`` here would turn both
    objectives into no-ops, since each only adds noise when ``is_train``
    is truthy.

    Args:
        example: the example to transform; handed to the selected method.
        is_train: training flag, forwarded unchanged.
        stats: statistics object, forwarded unchanged.
        **kwargs: any extra keyword arguments, forwarded unchanged.

    Returns:
        Whatever the selected ``apply_*`` method returns.

    Raises:
        NotImplementedError: if ``self.denoising_objective`` is neither
            ``'bart'`` nor ``'mass'``.
    """
    if self.denoising_objective == 'bart':
        return self.apply_bart(example, is_train=is_train, stats=stats, **kwargs)
    elif self.denoising_objective == 'mass':
        return self.apply_mass(example, is_train=is_train, stats=stats, **kwargs)
    else:
        raise NotImplementedError('Unknown denoising objective.')

def _repr_args(self):
"""Return str represent key arguments for BART."""
return repr(self.mass_noise)
return repr(self.bart_noise)

0 comments on commit 9f3aab8

Please sign in to comment.