diff --git a/docs/source/config_config.md b/docs/source/config_config.md
index 650439de..d64773b1 100644
--- a/docs/source/config_config.md
+++ b/docs/source/config_config.md
@@ -34,7 +34,7 @@ The meta-parameters under the `config_config` key:
 
 Path templates for source and target corpora, respectively.
 The path templates can contain the following variables that will be substituted by `config_config`:
-- Directional corpus mode 
+- Directional corpus mode
 - `{src_lang}`: The source language of the task
 - `{tgt_lang}`: The target language of the task
 - `{lang_pair}`: `{src_lang}-{tgt_lang}` for convenience
@@ -99,7 +99,7 @@ Generate translation configs for zero-shot directions.
 
 #### `transforms` and `ae_transforms`
 A list of transforms, for translation tasks and autoencoder tasks, respectively.
-Use this to apply subword segmentation, e.g. using `sentencepiece`, and `bart` noise for autoencoder.
+Use this to apply subword segmentation, e.g. using `sentencepiece`, and the `denoising` transform to add noise for the autoencoder tasks.
 Both of these may change the sequence length, necessitating a `filtertoolong` transform.
 
 #### `enc_sharing_groups` and `dec_sharing_groups`
diff --git a/examples/config_config.yaml b/examples/config_config.yaml
index 47fe8ca0..6ff337b7 100644
--- a/examples/config_config.yaml
+++ b/examples/config_config.yaml
@@ -16,7 +16,7 @@ config_config:
   ae_transforms:
     - sentencepiece
     - filtertoolong
-    - bart
+    - denoising
   enc_sharing_groups:
     - GROUP
     - FULL
@@ -27,7 +27,7 @@ config_config:
   translation_config_dir: config/translation.opus
   n_gpus_per_node: 4
   n_nodes: 2
-  
+
   # Note that this specifies the groups manually instead of clustering
   groups:
     "en": "en"
@@ -43,7 +43,7 @@ config_config:
 
 save_data: generated/opus.spm32k
 # vocabs serve two purposes: defines the vocab files, and gives the potential languages to consider
-src_vocab: 
+src_vocab:
   "af": "/scratch/project_2005099/data/opus/prepare_opus_data_tc_out/opusTC.afr.32k.spm.vocab"
   "da": "/scratch/project_2005099/data/opus/prepare_opus_data_tc_out/opusTC.dan.32k.spm.vocab"
   "en": "/scratch/project_2005099/data/opus/prepare_opus_data_tc_out/opusTC.eng.32k.spm.vocab"
diff --git a/onmt/inputters/dataset.py b/onmt/inputters/dataset.py
index c0f1b2e6..a8044dec 100644
--- a/onmt/inputters/dataset.py
+++ b/onmt/inputters/dataset.py
@@ -18,12 +18,15 @@ class Batch():
 
     src: tuple  # of torch Tensors
     tgt: torch.Tensor
+    labels: torch.Tensor
     batch_size: int
 
     def to(self, device):
         self.src = (self.src[0].to(device), self.src[1].to(device))
         if self.tgt is not None:
             self.tgt = self.tgt.to(device)
+        if self.labels is not None:
+            self.labels = self.labels.to(device)
         return self
 
 
@@ -156,8 +159,16 @@ def collate_fn(self, examples):
         tgt_padidx = self.vocabs['tgt'][DefaultTokens.PAD]
         src_lengths = torch.tensor([ex['src'].numel() for ex in examples], device='cpu')
         src = (pad_sequence([ex['src'] for ex in examples], padding_value=src_padidx).unsqueeze(-1), src_lengths)
-        tgt = pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1) if has_tgt else None
-        batch = Batch(src, tgt, len(examples))
+        if has_tgt:
+            tgt = pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1)
+            if 'labels' not in examples[0].keys():
+                labels = tgt
+            else:
+                labels = pad_sequence([ex['labels'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1)
+        else:
+            tgt = None
+            labels = None
+        batch = Batch(src, tgt, labels, len(examples))
         return batch
 
 
diff --git a/onmt/tests/test_subword_marker.py b/onmt/tests/test_subword_marker.py
index 8987cbc0..afa17fcf 100644
--- a/onmt/tests/test_subword_marker.py
+++ b/onmt/tests/test_subword_marker.py
@@ -1,6 +1,6 @@
 import unittest
 
-from onmt.transforms.bart import word_start_finder
+from onmt.transforms.denoising import word_start_finder
 from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer
 from onmt.constants import SubwordMarker
 
diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py
index 234e6ebb..32e25f83 100644
--- a/onmt/tests/test_transform.py
+++ b/onmt/tests/test_transform.py
@@ -11,7 +11,7 @@
     make_transforms,
     TransformPipe,
 )
-from onmt.transforms.bart import BARTNoising
+from onmt.transforms.denoising import BARTNoising
 
 
 class TestTransform(unittest.TestCase):
@@ -22,7 +22,7 @@ def test_transform_register(self):
             "sentencepiece",
             "bpe",
             "onmt_tokenize",
-            "bart",
+            "denoising",
             "switchout",
             "tokendrop",
             "tokenmask",
@@ -30,14 +30,14 @@ def test_transform_register(self):
         get_transforms_cls(builtin_transform)
 
     def test_vocab_required_transform(self):
-        transforms_cls = get_transforms_cls(["bart", "switchout"])
+        transforms_cls = get_transforms_cls(["denoising", "switchout"])
         opt = Namespace(seed=-1, switchout_temperature=1.0)
         # transforms that require vocab will not create if not provide vocab
         transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None)
         self.assertEqual(len(transforms), 0)
         with self.assertRaises(ValueError):
             transforms_cls["switchout"](opt).warm_up(vocabs=None)
-            transforms_cls["bart"](opt).warm_up(vocabs=None)
+            transforms_cls["denoising"](opt).warm_up(vocabs=None)
 
     def test_transform_specials(self):
         transforms_cls = get_transforms_cls(["prefix"])
@@ -516,6 +516,12 @@ def test_span_infilling(self):
         # print(f"Text Span Infilling: {infillied} / {tokens}")
         # print(n_words, n_masked)
 
+    def test_vocab_required_transform(self):
+        transforms_cls = get_transforms_cls(["denoising"])
+        opt = Namespace(random_ratio=1, denoising_objective='mass')
+        with self.assertRaises(ValueError):
+            make_transforms(opt, transforms_cls, vocabs=None, task=None)
+
 
 class TestFeaturesTransform(unittest.TestCase):
     def test_inferfeats(self):
diff --git a/onmt/trainer.py b/onmt/trainer.py
index 0d58d8ce..0b76a3a3 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -465,7 +465,7 @@ def _gradient_accumulation_over_lang_pairs(
             seen_comm_batches.add(comm_batch)
             if self.norm_method == "tokens":
                 num_tokens = (
-                    batch.tgt[1:, :, 0].ne(self.train_loss_md[f'trainloss{metadata.tgt_lang}'].padding_idx).sum()
+                    batch.labels[1:, :, 0].ne(self.train_loss_md[f'trainloss{metadata.tgt_lang}'].padding_idx).sum()
                 )
                 normalization += num_tokens.item()
             else:
@@ -485,6 +485,9 @@ def _gradient_accumulation_over_lang_pairs(
             if src_lengths is not None:
                 report_stats.n_src_words += src_lengths.sum().item()
 
+            # tgt_outer corresponds to the target-side input. The expected
+            # decoder output will be read directly from the batch:
+            # cf. `onmt.utils.loss.CommonLossCompute._make_shard_state`
             tgt_outer = batch.tgt
 
             bptt = False
diff --git a/onmt/transforms/bart.py b/onmt/transforms/denoising.py
similarity index 89%
rename from onmt/transforms/bart.py
rename to onmt/transforms/denoising.py
index 5f909317..36fa8c44 100644
--- a/onmt/transforms/bart.py
+++ b/onmt/transforms/denoising.py
@@ -322,10 +322,13 @@ def __repr__(self):
         return '{}({})'.format(cls_name, cls_args)
 
 
-@register_transform(name='bart')
-class BARTNoiseTransform(Transform):
+@register_transform(name='denoising')
+class NoiseTransform(Transform):
     def __init__(self, opts):
+        if opts.random_ratio > 0 and opts.denoising_objective == 'mass':
+            raise ValueError('Random word replacement is incompatible with MASS')
         super().__init__(opts)
+        self.denoising_objective = opts.denoising_objective
 
     @classmethod
     def get_specials(cls, opts):
@@ -338,8 +341,8 @@ def _set_seed(self, seed):
 
     @classmethod
     def add_options(cls, parser):
-        """Avalilable options relate to BART."""
-        group = parser.add_argument_group("Transform/BART")
+        """Available options relate to BART/MASS denoising."""
+        group = parser.add_argument_group("Transform/Denoising AE")
         group.add(
             "--permute_sent_ratio",
             "-permute_sent_ratio",
@@ -361,7 +364,7 @@ def add_options(cls, parser):
             "-random_ratio",
             type=float,
             default=0.0,
-            help="Instead of using {}, use random token this often.".format(DefaultTokens.MASK),
+            help=f"Instead of using {DefaultTokens.MASK}, use random token this often. Incompatible with MASS",
         )
 
         group.add(
@@ -394,6 +397,13 @@ def add_options(cls, parser):
             choices=[-1, 0, 1],
             help="When masking N tokens, replace with 0, 1, or N tokens. (use -1 for N)",
         )
+        group.add(
+            "--denoising_objective",
+            type=str,
+            default='bart',
+            choices=['bart', 'mass'],
+            help='Choose between BART-style and MASS-style denoising objectives.'
+        )
 
     @classmethod
     def require_vocab(cls):
@@ -424,13 +434,42 @@ def warm_up(self, vocabs):
             is_joiner=is_joiner,
         )
 
-    def apply(self, example, is_train=False, stats=None, **kwargs):
+    def apply_bart(self, example, is_train=False, stats=None, **kwargs):
         """Apply BART noise to src side tokens."""
         if is_train:
             src = self.bart_noise.apply(example['src'])
             example['src'] = src
         return example
 
+    def apply_mass(self, example, is_train=False, stats=None, **kwargs):
+        """Apply BART noise to src side tokens, then build the complementary MASS-style target and labels."""
+        if is_train:
+            masked = self.bart_noise.apply(example['src'])
+            complement_masked = [
+                DefaultTokens.MASK if masked_item != DefaultTokens.MASK
+                else source_item
+                for masked_item, source_item in zip(masked, example['src'])
+            ]
+            labels = [
+                DefaultTokens.PAD if cmasked_item == DefaultTokens.MASK
+                else cmasked_item
+                for cmasked_item in complement_masked
+            ]
+            example = {
+                'src': masked,
+                'tgt': complement_masked,
+                'labels': labels,
+            }
+        return example
+
+    def apply(self, example, is_train=False, stats=None, **kwargs):
+        if self.denoising_objective == 'bart':
+            return self.apply_bart(example, is_train=is_train, stats=stats, **kwargs)
+        elif self.denoising_objective == 'mass':
+            return self.apply_mass(example, is_train=is_train, stats=stats, **kwargs)
+        else:
+            raise NotImplementedError('Unknown denoising objective.')
+
     def _repr_args(self):
         """Return str represent key arguments for BART."""
         return repr(self.bart_noise)
diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py
index dc54deba..100e7eb6 100644
--- a/onmt/utils/loss.py
+++ b/onmt/utils/loss.py
@@ -179,19 +179,19 @@ def __call__(self, batch, output, attns, normalization=1.0, shard_size=0, trunc_
             batch_stats.update(stats)
         return None, batch_stats
 
-    def _stats(self, loss, scores, target):
+    def _stats(self, loss, scores, labels):
         """
         Args:
             loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
             scores (:obj:`FloatTensor`): a score for each possible output
-            target (:obj:`FloatTensor`): true targets
+            labels (:obj:`FloatTensor`): true targets
 
         Returns:
             :obj:`onmt.utils.Statistics` : statistics for this batch.
         """
         pred = scores.max(1)[1]
-        non_padding = target.ne(self.padding_idx)
-        num_correct = pred.eq(target).masked_select(non_padding).sum().item()
+        non_padding = labels.ne(self.padding_idx)
+        num_correct = pred.eq(labels).masked_select(non_padding).sum().item()
         num_non_padding = non_padding.sum().item()
         return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct)
 
@@ -221,14 +221,14 @@ def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
 
         self.confidence = 1.0 - label_smoothing
 
-    def forward(self, output, target):
+    def forward(self, output, labels):
         """
         output (FloatTensor): batch_size x n_classes
-        target (LongTensor): batch_size
+        labels (LongTensor): batch_size
         """
-        model_prob = self.one_hot.repeat(target.size(0), 1)
-        model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
-        model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)
+        model_prob = self.one_hot.repeat(labels.size(0), 1)
+        model_prob.scatter_(1, labels.unsqueeze(1), self.confidence)
+        model_prob.masked_fill_((labels == self.ignore_index).unsqueeze(1), 0)
 
         return F.kl_div(output, model_prob, reduction='sum')
 
@@ -262,12 +262,14 @@ def _add_coverage_shard_state(self, shard_state, attns):
             )
         shard_state.update({"std_attn": attns.get("std"), "coverage_attn": coverage})
 
-    def _compute_loss(self, batch, output, target, std_attn=None, coverage_attn=None, align_head=None, ref_align=None):
+    def _compute_loss(
+        self, batch, output, target, labels, std_attn=None, coverage_attn=None, align_head=None, ref_align=None
+    ):
         bottled_output = self._bottle(output)
 
         scores = self.generator(bottled_output)
-        gtruth = target.view(-1)
+        gtruth = labels.view(-1)
 
         loss = self.criterion(scores, gtruth)
 
         if self.lambda_coverage != 0.0:
@@ -327,7 +329,9 @@ def _make_shard_state(self, batch, output, range_, attns=None):
         range_end = range_[1]
         shard_state = {
             "output": output,
+            # TODO: target here is likely unnecessary, as it now corresponds to target-side input
            "target": batch.tgt[range_start:range_end, :, 0],
+            "labels": batch.labels[range_start:range_end, :, 0],
         }
         if self.lambda_coverage != 0.0:
             self._add_coverage_shard_state(shard_state, attns)
diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py
index a0b4ae1d..e979e338 100644
--- a/onmt/utils/parse.py
+++ b/onmt/utils/parse.py
@@ -168,7 +168,7 @@ def _get_all_transform(cls, opt):
         if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0:
            if not all_transforms.isdisjoint({'sentencepiece', 'bpe', 'onmt_tokenize'}):
                 raise ValueError('lambda_align is not compatible with on-the-fly tokenization.')
-            if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'bart'}):
+            if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'denoising'}):
                 raise ValueError('lambda_align is not compatible yet with potential token deletion/addition.')
         opt._all_transform = all_transforms
 
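
For reference, a minimal standalone sketch of the `src`/`tgt`/`labels` triple that the new `apply_mass` builds. The token strings and the pre-masked sequence are made up for illustration, and `<mask>`/`<blank>` merely stand in for `DefaultTokens.MASK` and `DefaultTokens.PAD`; the actual masking is produced by `BARTNoising`, not hand-written as here:

```python
# Toy illustration of the MASS-style complement masking implemented in apply_mass.
# MASK / PAD are illustrative stand-ins for DefaultTokens.MASK / DefaultTokens.PAD.
MASK, PAD = '<mask>', '<blank>'

src = ['the', 'quick', 'brown', 'fox', 'jumps']
# Assume the BART noising step masked the 2nd and 3rd tokens:
masked = ['the', MASK, MASK, 'fox', 'jumps']

# Target-side (decoder) input: the complement of the encoder-side mask.
complement_masked = [MASK if m != MASK else s for m, s in zip(masked, src)]

# Labels: real tokens only where the encoder input was masked; padding elsewhere,
# so those positions are ignored by the loss criterion and by _stats().
labels = [PAD if c == MASK else c for c in complement_masked]

print(masked)             # ['the', '<mask>', '<mask>', 'fox', 'jumps']         -> example['src']
print(complement_masked)  # ['<mask>', 'quick', 'brown', '<mask>', '<mask>']    -> example['tgt']
print(labels)             # ['<blank>', 'quick', 'brown', '<blank>', '<blank>'] -> example['labels']
```

With the BART objective nothing changes downstream: `collate_fn` falls back to `labels = tgt` when the example dict has no `'labels'` key, so `batch.labels` is always populated for the loss computation and token normalization shown in the diff above.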