
Commit

mass, commented
Mickus Timothee committed Sep 25, 2023
1 parent 4a93e0e commit 904a706
Showing 7 changed files with 20 additions and 8 deletions.
docs/source/config_config.md: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ Generate translation configs for zero-shot directions.
#### `transforms` and `ae_transforms`

A list of transforms, for translation tasks and autoencoder tasks, respectively.
Use this to apply subword segmentation, e.g. using `sentencepiece`, and `ae_noise` noise for autoencoder.
Use this to apply subword segmentation, e.g. using `sentencepiece`, and `denoising` noise for autoencoder.
Both of these may change the sequence length, necessitating a `filtertoolong` transform.

#### `enc_sharing_groups` and `dec_sharing_groups`
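To make the `transforms` and `ae_transforms` keys documented above concrete, a minimal `config_config` sketch using the renamed transform might look as follows. This is an illustrative reading of the docs and the example file, not a verbatim excerpt; the exact transform lists depend on the setup.

config_config:
  # subword segmentation and length filtering for translation tasks
  transforms:
    - sentencepiece
    - filtertoolong
  # autoencoder tasks additionally apply the (renamed) denoising transform
  ae_transforms:
    - sentencepiece
    - filtertoolong
    - denoising

Both lists can change sequence lengths, which is why `filtertoolong` appears in each.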
examples/config_config.yaml: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ config_config:
ae_transforms:
- sentencepiece
- filtertoolong
- ae_noise
- denoising
enc_sharing_groups:
- GROUP
- FULL
onmt/tests/test_transform.py: 9 additions & 3 deletions
@@ -22,22 +22,22 @@ def test_transform_register(self):
"sentencepiece",
"bpe",
"onmt_tokenize",
"ae_noise",
"denoising",
"switchout",
"tokendrop",
"tokenmask",
]
get_transforms_cls(builtin_transform)

def test_vocab_required_transform(self):
transforms_cls = get_transforms_cls(["ae_noise", "switchout"])
transforms_cls = get_transforms_cls(["denoising", "switchout"])
opt = Namespace(seed=-1, switchout_temperature=1.0)
# transforms that require vocab will not be created if no vocab is provided
transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None)
self.assertEqual(len(transforms), 0)
with self.assertRaises(ValueError):
transforms_cls["switchout"](opt).warm_up(vocabs=None)
transforms_cls["ae_noise"](opt).warm_up(vocabs=None)
transforms_cls["denoising"](opt).warm_up(vocabs=None)

def test_transform_specials(self):
transforms_cls = get_transforms_cls(["prefix"])
@@ -516,6 +516,12 @@ def test_span_infilling(self):
# print(f"Text Span Infilling: {infillied} / {tokens}")
# print(n_words, n_masked)

def test_vocab_required_transform(self):
transforms_cls = get_transforms_cls(["denoising"])
opt = Namespace(random_ratio=1, denoising_objective='mass')
with self.assertRaises(ValueError):
transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None)


class TestFeaturesTransform(unittest.TestCase):
def test_inferfeats(self):
onmt/trainer.py: 3 additions & 0 deletions
@@ -481,6 +481,9 @@ def _gradient_accumulation_over_lang_pairs(
if src_lengths is not None:
report_stats.n_src_words += src_lengths.sum().item()

# tgt_outer corresponds to the target-side input. The expected
# decoder output will be read directly from the batch:
# cf. `onmt.utils.loss.CommonLossCompute._make_shard_state`
tgt_outer = batch.tgt

bptt = False
onmt/transforms/denoising.py: 4 additions & 2 deletions
@@ -322,9 +322,11 @@ def __repr__(self):
return '{}({})'.format(cls_name, cls_args)


@register_transform(name='ae_noise')
@register_transform(name='denoising')
class NoiseTransform(Transform):
def __init__(self, opts):
if opts.random_ratio > 0 and opts.denoising_objective == 'mass':
raise ValueError('Random word replacement is incompatible with MASS')
super().__init__(opts)
self.denoising_objective = opts.denoising_objective

@@ -362,7 +364,7 @@ def add_options(cls, parser):
"-random_ratio",
type=float,
default=0.0,
help="Instead of using {}, use random token this often.".format(DefaultTokens.MASK),
help=f"Instead of using {DefaultTokens.MASK}, use random token this often. Incompatible with MASS",
)

group.add(
onmt/utils/loss.py: 1 addition & 0 deletions
@@ -329,6 +329,7 @@ def _make_shard_state(self, batch, output, range_, attns=None):
range_end = range_[1]
shard_state = {
"output": output,
# TODO: target here is likely unnecessary, as it now corresponds to target-side input
"target": batch.tgt[range_start:range_end, :, 0],
"labels": batch.labels[range_start:range_end, :, 0],
}
onmt/utils/parse.py: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ def _get_all_transform(cls, opt):
if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0:
if not all_transforms.isdisjoint({'sentencepiece', 'bpe', 'onmt_tokenize'}):
raise ValueError('lambda_align is not compatible with on-the-fly tokenization.')
if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'ae_noise'}):
if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'denoising'}):
raise ValueError('lambda_align is not compatible yet with potential token deletion/addition.')
opt._all_transform = all_transforms

