diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py
index 234e6ebb..0cffb0c8 100644
--- a/onmt/tests/test_transform.py
+++ b/onmt/tests/test_transform.py
@@ -22,7 +22,7 @@ def test_transform_register(self):
             "sentencepiece",
             "bpe",
             "onmt_tokenize",
-            "bart",
+            "ae_noise",
             "switchout",
             "tokendrop",
             "tokenmask",
@@ -30,14 +30,14 @@ def test_transform_register(self):
         get_transforms_cls(builtin_transform)
 
     def test_vocab_required_transform(self):
-        transforms_cls = get_transforms_cls(["bart", "switchout"])
+        transforms_cls = get_transforms_cls(["ae_noise", "switchout"])
         opt = Namespace(seed=-1, switchout_temperature=1.0)
         # transforms that require vocab will not create if not provide vocab
         transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None)
         self.assertEqual(len(transforms), 0)
         with self.assertRaises(ValueError):
             transforms_cls["switchout"](opt).warm_up(vocabs=None)
-            transforms_cls["bart"](opt).warm_up(vocabs=None)
+            transforms_cls["ae_noise"](opt).warm_up(vocabs=None)
 
     def test_transform_specials(self):
         transforms_cls = get_transforms_cls(["prefix"])
diff --git a/onmt/transforms/bart.py b/onmt/transforms/bart.py
index 94a6a79d..aa46b062 100644
--- a/onmt/transforms/bart.py
+++ b/onmt/transforms/bart.py
@@ -322,10 +322,11 @@ def __repr__(self):
         return '{}({})'.format(cls_name, cls_args)
 
 
-@register_transform(name='bart')
-class BARTNoiseTransform(Transform):
+@register_transform(name='ae_noise')
+class NoiseTransform(Transform):
     def __init__(self, opts):
         super().__init__(opts)
+        self.denoising_objective = opts.denoising_objective
 
     @classmethod
     def get_specials(cls, opts):
@@ -338,8 +339,8 @@ def _set_seed(self, seed):
 
     @classmethod
     def add_options(cls, parser):
-        """Avalilable options relate to BART."""
-        group = parser.add_argument_group("Transform/BART")
+        """Available options related to BART/MASS denoising."""
+        group = parser.add_argument_group("Transform/Denoising AE")
         group.add(
             "--permute_sent_ratio",
             "-permute_sent_ratio",
@@ -394,6 +395,13 @@ def add_options(cls, parser):
             choices=[-1, 0, 1],
             help="When masking N tokens, replace with 0, 1, or N tokens. (use -1 for N)",
         )
+        group.add(
+            "--denoising_objective",
+            type=str,
+            default='bart',
+            choices=['bart', 'mass'],
+            help='choose between BART-style or MASS-style denoising objectives'
+        )
 
     @classmethod
     def require_vocab(cls):
@@ -424,124 +432,17 @@ def warm_up(self, vocabs):
             is_joiner=is_joiner,
         )
 
-    def apply(self, example, is_train=False, stats=None, **kwargs):
+    def apply_bart(self, example, is_train=False, stats=None, **kwargs):
         """Apply BART noise to src side tokens."""
         if is_train:
             src = self.bart_noise.apply(example['src'])
             example['src'] = src
         return example
 
-    def _repr_args(self):
-        """Return str represent key arguments for BART."""
-        return repr(self.bart_noise)
-
-
-@register_transform(name='mass')
-class MASSNoiseTransform(Transform):
-    def __init__(self, opts):
-        super().__init__(opts)
-
-    @classmethod
-    def get_specials(cls, opts):
-        # FIXME: If a different mask token is used, then it is up to you to add it to specials
-        return ({DefaultTokens.MASK}, set())
-
-    def _set_seed(self, seed):
-        """set seed to ensure reproducibility."""
-        BARTNoising.set_random_seed(seed)
-
-    @classmethod
-    def add_options(cls, parser):
-        """Avalilable options relate to BART."""
-        group = parser.add_argument_group("Transform/MASS")
-        group.add(
-            "--permute_sent_ratio",
-            "-permute_sent_ratio",
-            type=float,
-            default=0.0,
-            help="Permute this proportion of sentences "
-            "(boundaries defined by {}) in all inputs.".format(DefaultTokens.SENT_FULL_STOPS),
-        )
-        group.add("--rotate_ratio", "-rotate_ratio", type=float, default=0.0, help="Rotate this proportion of inputs.")
-        group.add(
-            "--insert_ratio",
-            "-insert_ratio",
-            type=float,
-            default=0.0,
-            help="Insert this percentage of additional random tokens.",
-        )
-        group.add(
-            "--random_ratio",
-            "-random_ratio",
-            type=float,
-            default=0.0,
-            help="Instead of using {}, use random token this often.".format(DefaultTokens.MASK),
-        )
-
-        group.add(
-            "--mask_ratio",
-            "-mask_ratio",
-            type=float,
-            default=0.0,
-            help="Fraction of words/subwords that will be masked.",
-        )
-        group.add(
-            "--mask_length",
-            "-mask_length",
-            type=str,
-            default="subword",
-            choices=["subword", "word", "span-poisson"],
-            help="Length of masking window to apply.",
-        )
-        group.add(
-            "--poisson_lambda",
-            "-poisson_lambda",
-            type=float,
-            default=3.0,
-            help="Lambda for Poisson distribution to sample span length if `-mask_length` set to span-poisson.",
-        )
-        group.add(
-            "--replace_length",
-            "-replace_length",
-            type=int,
-            default=-1,
-            choices=[-1, 0, 1],
-            help="When masking N tokens, replace with 0, 1, or N tokens. (use -1 for N)",
-        )
-
-    @classmethod
-    def require_vocab(cls):
-        """Override this method to inform it need vocab to start."""
-        return True
-
-    def warm_up(self, vocabs):
-        super().warm_up(vocabs)
-
-        subword_type = self.opts.src_subword_type
-        if self.opts.mask_length == 'subword':
-            if subword_type == 'none':
-                raise ValueError(
-                    f'src_subword_type={subword_type} incompatible with ' f'mask_length={self.opts.mask_length}!'
-                )
-        is_joiner = (subword_type == 'bpe') if subword_type != 'none' else None
-        self.mass_noise = BARTNoising(
-            self.vocabs['src'].itos,
-            mask_tok=DefaultTokens.MASK,
-            mask_ratio=self.opts.mask_ratio,
-            insert_ratio=self.opts.insert_ratio,
-            permute_sent_ratio=self.opts.permute_sent_ratio,
-            poisson_lambda=self.opts.poisson_lambda,
-            replace_length=self.opts.replace_length,
-            rotate_ratio=self.opts.rotate_ratio,
-            mask_length=self.opts.mask_length,
-            random_ratio=self.opts.random_ratio,
-            is_joiner=is_joiner,
-        )
-
-    def apply(self, example, is_train=False, stats=None, **kwargs):
+    def apply_mass(self, example, is_train=False, stats=None, **kwargs):
         """Apply BART noise to src side tokens, then complete as MASS scheme."""
         if is_train:
-            masked = self.mass_noise.apply(example['src'])
+            masked = self.bart_noise.apply(example['src'])
             complement_masked = [
                 DefaultTokens.MASK if masked_item != DefaultTokens.MASK
                 else source_item
@@ -559,6 +460,17 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
             }
         return example
 
+    def apply(self, example, is_train=False, stats=None, **kwargs):
+        """Apply the configured denoising objective (BART or MASS) to an example."""
+        # Forward the caller's is_train/stats: apply_bart and apply_mass only add
+        # noise inside `if is_train:`, so hard-coding False would skip it always.
+        if self.denoising_objective == 'bart':
+            return self.apply_bart(example, is_train=is_train, stats=stats, **kwargs)
+        elif self.denoising_objective == 'mass':
+            return self.apply_mass(example, is_train=is_train, stats=stats, **kwargs)
+        else:
+            raise NotImplementedError('Unknown denoising objective.')
+
     def _repr_args(self):
         """Return str represent key arguments for BART."""
-        return repr(self.mass_noise)
+        return repr(self.bart_noise)