From 21d21e14ba5e9fff0213914a7bece9b598f53b76 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Thu, 21 Sep 2023 11:23:19 +0300 Subject: [PATCH 1/7] mass impl start --- onmt/inputters/dataset.py | 15 ++++- onmt/trainer.py | 2 +- onmt/transforms/bart.py | 128 ++++++++++++++++++++++++++++++++++++++ onmt/utils/loss.py | 7 ++- 4 files changed, 147 insertions(+), 5 deletions(-) diff --git a/onmt/inputters/dataset.py b/onmt/inputters/dataset.py index c0f1b2e6..5cf29f34 100644 --- a/onmt/inputters/dataset.py +++ b/onmt/inputters/dataset.py @@ -18,12 +18,15 @@ class Batch(): src: tuple # of torch Tensors tgt: torch.Tensor + label: torch.Tensor batch_size: int def to(self, device): self.src = (self.src[0].to(device), self.src[1].to(device)) if self.tgt is not None: self.tgt = self.tgt.to(device) + if self.label is not None: + self.label = self.label.to(device) return self @@ -156,8 +159,16 @@ def collate_fn(self, examples): tgt_padidx = self.vocabs['tgt'][DefaultTokens.PAD] src_lengths = torch.tensor([ex['src'].numel() for ex in examples], device='cpu') src = (pad_sequence([ex['src'] for ex in examples], padding_value=src_padidx).unsqueeze(-1), src_lengths) - tgt = pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1) if has_tgt else None - batch = Batch(src, tgt, len(examples)) + if has_tgt: + tgt = pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1) + if 'labels' not in examples[0].keys(): + labels = tgt + else: + labels = pad_sequence([ex['labels'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1) + else: + tgt = None + labels = None + batch = Batch(src, tgt, labels, len(examples)) return batch diff --git a/onmt/trainer.py b/onmt/trainer.py index f0fa8efd..5ab396b3 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -461,7 +461,7 @@ def _gradient_accumulation_over_lang_pairs( seen_comm_batches.add(comm_batch) if self.norm_method == "tokens": num_tokens = ( - batch.tgt[1:, :, 0].ne(self.train_loss_md[f'trainloss{metadata.tgt_lang}'].padding_idx).sum() + batch.labels[1:, :, 0].ne(self.train_loss_md[f'trainloss{metadata.tgt_lang}'].padding_idx).sum() ) normalization += num_tokens.item() else: diff --git a/onmt/transforms/bart.py b/onmt/transforms/bart.py index 5f909317..94a6a79d 100644 --- a/onmt/transforms/bart.py +++ b/onmt/transforms/bart.py @@ -434,3 +434,131 @@ def apply(self, example, is_train=False, stats=None, **kwargs): def _repr_args(self): """Return str represent key arguments for BART.""" return repr(self.bart_noise) + + +@register_transform(name='mass') +class MASSNoiseTransform(Transform): + def __init__(self, opts): + super().__init__(opts) + + @classmethod + def get_specials(cls, opts): + # FIXME: If a different mask token is used, then it is up to you to add it to specials + return ({DefaultTokens.MASK}, set()) + + def _set_seed(self, seed): + """set seed to ensure reproducibility.""" + BARTNoising.set_random_seed(seed) + + @classmethod + def add_options(cls, parser): + """Avalilable options relate to BART.""" + group = parser.add_argument_group("Transform/MASS") + group.add( + "--permute_sent_ratio", + "-permute_sent_ratio", + type=float, + default=0.0, + help="Permute this proportion of sentences " + "(boundaries defined by {}) in all inputs.".format(DefaultTokens.SENT_FULL_STOPS), + ) + group.add("--rotate_ratio", "-rotate_ratio", type=float, default=0.0, help="Rotate this proportion of inputs.") + group.add( + "--insert_ratio", + "-insert_ratio", + type=float, + default=0.0, + help="Insert this 
percentage of additional random tokens.", + ) + group.add( + "--random_ratio", + "-random_ratio", + type=float, + default=0.0, + help="Instead of using {}, use random token this often.".format(DefaultTokens.MASK), + ) + + group.add( + "--mask_ratio", + "-mask_ratio", + type=float, + default=0.0, + help="Fraction of words/subwords that will be masked.", + ) + group.add( + "--mask_length", + "-mask_length", + type=str, + default="subword", + choices=["subword", "word", "span-poisson"], + help="Length of masking window to apply.", + ) + group.add( + "--poisson_lambda", + "-poisson_lambda", + type=float, + default=3.0, + help="Lambda for Poisson distribution to sample span length if `-mask_length` set to span-poisson.", + ) + group.add( + "--replace_length", + "-replace_length", + type=int, + default=-1, + choices=[-1, 0, 1], + help="When masking N tokens, replace with 0, 1, or N tokens. (use -1 for N)", + ) + + @classmethod + def require_vocab(cls): + """Override this method to inform it need vocab to start.""" + return True + + def warm_up(self, vocabs): + super().warm_up(vocabs) + + subword_type = self.opts.src_subword_type + if self.opts.mask_length == 'subword': + if subword_type == 'none': + raise ValueError( + f'src_subword_type={subword_type} incompatible with ' f'mask_length={self.opts.mask_length}!' + ) + is_joiner = (subword_type == 'bpe') if subword_type != 'none' else None + self.mass_noise = BARTNoising( + self.vocabs['src'].itos, + mask_tok=DefaultTokens.MASK, + mask_ratio=self.opts.mask_ratio, + insert_ratio=self.opts.insert_ratio, + permute_sent_ratio=self.opts.permute_sent_ratio, + poisson_lambda=self.opts.poisson_lambda, + replace_length=self.opts.replace_length, + rotate_ratio=self.opts.rotate_ratio, + mask_length=self.opts.mask_length, + random_ratio=self.opts.random_ratio, + is_joiner=is_joiner, + ) + + def apply(self, example, is_train=False, stats=None, **kwargs): + """Apply BART noise to src side tokens, then complete as MASS scheme.""" + if is_train: + masked = self.mass_noise.apply(example['src']) + complement_masked = [ + DefaultTokens.MASK if masked_item != DefaultTokens.MASK + else source_item + for masked_item, source_item in zip(masked, example['src']) + ] + labels = [ + DefaultTokens.PAD if cmasked_item == DefaultTokens.MASK + else cmasked_item + for cmasked_item in complement_masked + ] + example = { + 'src': masked, + 'tgt': complement_masked, + 'labels': labels, + } + return example + + def _repr_args(self): + """Return str represent key arguments for BART.""" + return repr(self.mass_noise) diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index dc54deba..1b2f2701 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -262,12 +262,14 @@ def _add_coverage_shard_state(self, shard_state, attns): ) shard_state.update({"std_attn": attns.get("std"), "coverage_attn": coverage}) - def _compute_loss(self, batch, output, target, std_attn=None, coverage_attn=None, align_head=None, ref_align=None): + def _compute_loss( + self, batch, output, target, labels, std_attn=None, coverage_attn=None, align_head=None, ref_align=None + ): bottled_output = self._bottle(output) scores = self.generator(bottled_output) - gtruth = target.view(-1) + gtruth = labels.view(-1) loss = self.criterion(scores, gtruth) if self.lambda_coverage != 0.0: @@ -328,6 +330,7 @@ def _make_shard_state(self, batch, output, range_, attns=None): shard_state = { "output": output, "target": batch.tgt[range_start:range_end, :, 0], + "labels": batch.labels[range_start:range_end, :, 0], } if 
self.lambda_coverage != 0.0: self._add_coverage_shard_state(shard_state, attns) From 9f3aab8926022230fbb34cc54827b1e47a017ef7 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Thu, 21 Sep 2023 11:36:04 +0300 Subject: [PATCH 2/7] merging mass and bart --- onmt/tests/test_transform.py | 6 +- onmt/transforms/bart.py | 139 ++++++----------------------------- 2 files changed, 27 insertions(+), 118 deletions(-) diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py index 234e6ebb..0cffb0c8 100644 --- a/onmt/tests/test_transform.py +++ b/onmt/tests/test_transform.py @@ -22,7 +22,7 @@ def test_transform_register(self): "sentencepiece", "bpe", "onmt_tokenize", - "bart", + "ae_noise", "switchout", "tokendrop", "tokenmask", @@ -30,14 +30,14 @@ def test_transform_register(self): get_transforms_cls(builtin_transform) def test_vocab_required_transform(self): - transforms_cls = get_transforms_cls(["bart", "switchout"]) + transforms_cls = get_transforms_cls(["ae_noise", "switchout"]) opt = Namespace(seed=-1, switchout_temperature=1.0) # transforms that require vocab will not create if not provide vocab transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None) self.assertEqual(len(transforms), 0) with self.assertRaises(ValueError): transforms_cls["switchout"](opt).warm_up(vocabs=None) - transforms_cls["bart"](opt).warm_up(vocabs=None) + transforms_cls["ae_noise"](opt).warm_up(vocabs=None) def test_transform_specials(self): transforms_cls = get_transforms_cls(["prefix"]) diff --git a/onmt/transforms/bart.py b/onmt/transforms/bart.py index 94a6a79d..aa46b062 100644 --- a/onmt/transforms/bart.py +++ b/onmt/transforms/bart.py @@ -322,10 +322,11 @@ def __repr__(self): return '{}({})'.format(cls_name, cls_args) -@register_transform(name='bart') -class BARTNoiseTransform(Transform): +@register_transform(name='ae_noise') +class NoiseTransform(Transform): def __init__(self, opts): super().__init__(opts) + self.denoising_objective = opts.denoising_objective @classmethod def get_specials(cls, opts): @@ -338,8 +339,8 @@ def _set_seed(self, seed): @classmethod def add_options(cls, parser): - """Avalilable options relate to BART.""" - group = parser.add_argument_group("Transform/BART") + """Avalilable options relate to BART/MASS denoising.""" + group = parser.add_argument_group("Transform/Denoising AE") group.add( "--permute_sent_ratio", "-permute_sent_ratio", @@ -394,6 +395,13 @@ def add_options(cls, parser): choices=[-1, 0, 1], help="When masking N tokens, replace with 0, 1, or N tokens. 
(use -1 for N)", ) + group.add( + "--denoising_objective", + type=str, + default='bart', + choices=['bart', 'mass'], + help='choose between BART-style or MASS-style denoising objectives' + ) @classmethod def require_vocab(cls): @@ -424,124 +432,17 @@ def warm_up(self, vocabs): is_joiner=is_joiner, ) - def apply(self, example, is_train=False, stats=None, **kwargs): + def apply_bart(self, example, is_train=False, stats=None, **kwargs): """Apply BART noise to src side tokens.""" if is_train: src = self.bart_noise.apply(example['src']) example['src'] = src return example - def _repr_args(self): - """Return str represent key arguments for BART.""" - return repr(self.bart_noise) - - -@register_transform(name='mass') -class MASSNoiseTransform(Transform): - def __init__(self, opts): - super().__init__(opts) - - @classmethod - def get_specials(cls, opts): - # FIXME: If a different mask token is used, then it is up to you to add it to specials - return ({DefaultTokens.MASK}, set()) - - def _set_seed(self, seed): - """set seed to ensure reproducibility.""" - BARTNoising.set_random_seed(seed) - - @classmethod - def add_options(cls, parser): - """Avalilable options relate to BART.""" - group = parser.add_argument_group("Transform/MASS") - group.add( - "--permute_sent_ratio", - "-permute_sent_ratio", - type=float, - default=0.0, - help="Permute this proportion of sentences " - "(boundaries defined by {}) in all inputs.".format(DefaultTokens.SENT_FULL_STOPS), - ) - group.add("--rotate_ratio", "-rotate_ratio", type=float, default=0.0, help="Rotate this proportion of inputs.") - group.add( - "--insert_ratio", - "-insert_ratio", - type=float, - default=0.0, - help="Insert this percentage of additional random tokens.", - ) - group.add( - "--random_ratio", - "-random_ratio", - type=float, - default=0.0, - help="Instead of using {}, use random token this often.".format(DefaultTokens.MASK), - ) - - group.add( - "--mask_ratio", - "-mask_ratio", - type=float, - default=0.0, - help="Fraction of words/subwords that will be masked.", - ) - group.add( - "--mask_length", - "-mask_length", - type=str, - default="subword", - choices=["subword", "word", "span-poisson"], - help="Length of masking window to apply.", - ) - group.add( - "--poisson_lambda", - "-poisson_lambda", - type=float, - default=3.0, - help="Lambda for Poisson distribution to sample span length if `-mask_length` set to span-poisson.", - ) - group.add( - "--replace_length", - "-replace_length", - type=int, - default=-1, - choices=[-1, 0, 1], - help="When masking N tokens, replace with 0, 1, or N tokens. (use -1 for N)", - ) - - @classmethod - def require_vocab(cls): - """Override this method to inform it need vocab to start.""" - return True - - def warm_up(self, vocabs): - super().warm_up(vocabs) - - subword_type = self.opts.src_subword_type - if self.opts.mask_length == 'subword': - if subword_type == 'none': - raise ValueError( - f'src_subword_type={subword_type} incompatible with ' f'mask_length={self.opts.mask_length}!' 
-                )
-        is_joiner = (subword_type == 'bpe') if subword_type != 'none' else None
-        self.mass_noise = BARTNoising(
-            self.vocabs['src'].itos,
-            mask_tok=DefaultTokens.MASK,
-            mask_ratio=self.opts.mask_ratio,
-            insert_ratio=self.opts.insert_ratio,
-            permute_sent_ratio=self.opts.permute_sent_ratio,
-            poisson_lambda=self.opts.poisson_lambda,
-            replace_length=self.opts.replace_length,
-            rotate_ratio=self.opts.rotate_ratio,
-            mask_length=self.opts.mask_length,
-            random_ratio=self.opts.random_ratio,
-            is_joiner=is_joiner,
-        )
-
-    def apply(self, example, is_train=False, stats=None, **kwargs):
+    def apply_mass(self, example, is_train=False, stats=None, **kwargs):
         """Apply BART noise to src side tokens, then complete as MASS scheme."""
         if is_train:
-            masked = self.mass_noise.apply(example['src'])
+            masked = self.bart_noise.apply(example['src'])
             complement_masked = [
                 DefaultTokens.MASK if masked_item != DefaultTokens.MASK
                 else source_item
                 for masked_item, source_item in zip(masked, example['src'])
             ]
             labels = [
                 DefaultTokens.PAD if cmasked_item == DefaultTokens.MASK
                 else cmasked_item
                 for cmasked_item in complement_masked
             ]
             example = {
                 'src': masked,
                 'tgt': complement_masked,
                 'labels': labels,
             }
         return example
 
+    def apply(self, example, is_train=False, stats=None, **kwargs):
+        if self.denoising_objective == 'bart':
+            return self.apply_bart(example, is_train=is_train, stats=stats, **kwargs)
+        elif self.denoising_objective == 'mass':
+            return self.apply_mass(example, is_train=is_train, stats=stats, **kwargs)
+        else:
+            raise NotImplementedError('Unknown denoising objective.')
+
     def _repr_args(self):
         """Return str represent key arguments for BART."""
-        return repr(self.mass_noise)
+        return repr(self.bart_noise)

From e801ea4c7d0fef5a7ca089a2f7c622b3399d9ffd Mon Sep 17 00:00:00 2001
From: Mickus Timothee
Date: Thu, 21 Sep 2023 11:38:30 +0300
Subject: [PATCH 3/7] more generic transform name

---
 docs/source/config_config.md              | 4 ++--
 examples/config_config.yaml               | 6 +++---
 onmt/tests/test_subword_marker.py         | 2 +-
 onmt/tests/test_transform.py              | 2 +-
 onmt/transforms/{bart.py => denoising.py} | 0
 onmt/utils/parse.py                       | 2 +-
 6 files changed, 8 insertions(+), 8 deletions(-)
 rename onmt/transforms/{bart.py => denoising.py} (100%)

diff --git a/docs/source/config_config.md b/docs/source/config_config.md
index 650439de..356df0dd 100644
--- a/docs/source/config_config.md
+++ b/docs/source/config_config.md
@@ -34,7 +34,7 @@ The meta-parameters under the `config_config` key:
 Path templates for source and target corpora, respectively.
 The path templates can contain the following variables that will be substituted by `config_config`:
 
-- Directional corpus mode 
+- Directional corpus mode
   - `{src_lang}`: The source language of the task
   - `{tgt_lang}`: The target language of the task
   - `{lang_pair}`: `{src_lang}-{tgt_lang}` for convenience
@@ -99,7 +99,7 @@ Generate translation configs for zero-shot directions.
 #### `transforms` and `ae_transforms`
 
 A list of transforms, for translation tasks and autoencoder tasks, respectively.
-Use this to apply subword segmentation, e.g. using `sentencepiece`, and `bart` noise for autoencoder.
+Use this to apply subword segmentation, e.g. using `sentencepiece`, and `ae_noise` noise for autoencoder.
 Both of these may change the sequence length, necessitating a `filtertoolong` transform.
#### `enc_sharing_groups` and `dec_sharing_groups` diff --git a/examples/config_config.yaml b/examples/config_config.yaml index 47fe8ca0..c66dc3e4 100644 --- a/examples/config_config.yaml +++ b/examples/config_config.yaml @@ -16,7 +16,7 @@ config_config: ae_transforms: - sentencepiece - filtertoolong - - bart + - ae_noise enc_sharing_groups: - GROUP - FULL @@ -27,7 +27,7 @@ config_config: translation_config_dir: config/translation.opus n_gpus_per_node: 4 n_nodes: 2 - + # Note that this specifies the groups manually instead of clustering groups: "en": "en" @@ -43,7 +43,7 @@ config_config: save_data: generated/opus.spm32k # vocabs serve two purposes: defines the vocab files, and gives the potential languages to consider -src_vocab: +src_vocab: "af": "/scratch/project_2005099/data/opus/prepare_opus_data_tc_out/opusTC.afr.32k.spm.vocab" "da": "/scratch/project_2005099/data/opus/prepare_opus_data_tc_out/opusTC.dan.32k.spm.vocab" "en": "/scratch/project_2005099/data/opus/prepare_opus_data_tc_out/opusTC.eng.32k.spm.vocab" diff --git a/onmt/tests/test_subword_marker.py b/onmt/tests/test_subword_marker.py index 8987cbc0..afa17fcf 100644 --- a/onmt/tests/test_subword_marker.py +++ b/onmt/tests/test_subword_marker.py @@ -1,6 +1,6 @@ import unittest -from onmt.transforms.bart import word_start_finder +from onmt.transforms.denoising import word_start_finder from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer from onmt.constants import SubwordMarker diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py index 0cffb0c8..7d6b58f0 100644 --- a/onmt/tests/test_transform.py +++ b/onmt/tests/test_transform.py @@ -11,7 +11,7 @@ make_transforms, TransformPipe, ) -from onmt.transforms.bart import BARTNoising +from onmt.transforms.denoising import BARTNoising class TestTransform(unittest.TestCase): diff --git a/onmt/transforms/bart.py b/onmt/transforms/denoising.py similarity index 100% rename from onmt/transforms/bart.py rename to onmt/transforms/denoising.py diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index a0b4ae1d..b29729a1 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -168,7 +168,7 @@ def _get_all_transform(cls, opt): if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0: if not all_transforms.isdisjoint({'sentencepiece', 'bpe', 'onmt_tokenize'}): raise ValueError('lambda_align is not compatible with on-the-fly tokenization.') - if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'bart'}): + if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'ae_noise'}): raise ValueError('lambda_align is not compatible yet with potential token deletion/addition.') opt._all_transform = all_transforms From 4a93e0e5574d30b8e32c2a8747c508a9417d2461 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Fri, 22 Sep 2023 17:23:40 +0300 Subject: [PATCH 4/7] naming convention --- onmt/inputters/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onmt/inputters/dataset.py b/onmt/inputters/dataset.py index 5cf29f34..a8044dec 100644 --- a/onmt/inputters/dataset.py +++ b/onmt/inputters/dataset.py @@ -18,15 +18,15 @@ class Batch(): src: tuple # of torch Tensors tgt: torch.Tensor - label: torch.Tensor + labels: torch.Tensor batch_size: int def to(self, device): self.src = (self.src[0].to(device), self.src[1].to(device)) if self.tgt is not None: self.tgt = self.tgt.to(device) - if self.label is not None: - self.label = self.label.to(device) + if self.labels is not None: + self.labels = self.labels.to(device) return self From 
904a706f488bb52b59cb046629d02ca84562fc40 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Mon, 25 Sep 2023 12:11:00 +0300 Subject: [PATCH 5/7] mass, commented --- docs/source/config_config.md | 2 +- examples/config_config.yaml | 2 +- onmt/tests/test_transform.py | 12 +++++++++--- onmt/trainer.py | 3 +++ onmt/transforms/denoising.py | 6 ++++-- onmt/utils/loss.py | 1 + onmt/utils/parse.py | 2 +- 7 files changed, 20 insertions(+), 8 deletions(-) diff --git a/docs/source/config_config.md b/docs/source/config_config.md index 356df0dd..d64773b1 100644 --- a/docs/source/config_config.md +++ b/docs/source/config_config.md @@ -99,7 +99,7 @@ Generate translation configs for zero-shot directions. #### `transforms` and `ae_transforms` A list of transforms, for translation tasks and autoencoder tasks, respectively. -Use this to apply subword segmentation, e.g. using `sentencepiece`, and `ae_noise` noise for autoencoder. +Use this to apply subword segmentation, e.g. using `sentencepiece`, and `denoising` noise for autoencoder. Both of these may change the sequence length, necessitating a `filtertoolong` transform. #### `enc_sharing_groups` and `dec_sharing_groups` diff --git a/examples/config_config.yaml b/examples/config_config.yaml index c66dc3e4..6ff337b7 100644 --- a/examples/config_config.yaml +++ b/examples/config_config.yaml @@ -16,7 +16,7 @@ config_config: ae_transforms: - sentencepiece - filtertoolong - - ae_noise + - denoising enc_sharing_groups: - GROUP - FULL diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py index 7d6b58f0..fc17ca7f 100644 --- a/onmt/tests/test_transform.py +++ b/onmt/tests/test_transform.py @@ -22,7 +22,7 @@ def test_transform_register(self): "sentencepiece", "bpe", "onmt_tokenize", - "ae_noise", + "denoising", "switchout", "tokendrop", "tokenmask", @@ -30,14 +30,14 @@ def test_transform_register(self): get_transforms_cls(builtin_transform) def test_vocab_required_transform(self): - transforms_cls = get_transforms_cls(["ae_noise", "switchout"]) + transforms_cls = get_transforms_cls(["denoising", "switchout"]) opt = Namespace(seed=-1, switchout_temperature=1.0) # transforms that require vocab will not create if not provide vocab transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None) self.assertEqual(len(transforms), 0) with self.assertRaises(ValueError): transforms_cls["switchout"](opt).warm_up(vocabs=None) - transforms_cls["ae_noise"](opt).warm_up(vocabs=None) + transforms_cls["denoising"](opt).warm_up(vocabs=None) def test_transform_specials(self): transforms_cls = get_transforms_cls(["prefix"]) @@ -516,6 +516,12 @@ def test_span_infilling(self): # print(f"Text Span Infilling: {infillied} / {tokens}") # print(n_words, n_masked) + def test_vocab_required_transform(self): + transforms_cls = get_transforms_cls(["denoising"]) + opt = Namespace(random_ratio=1, denoising_objective='mass') + with self.assertRaises(ValueError): + transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None) + class TestFeaturesTransform(unittest.TestCase): def test_inferfeats(self): diff --git a/onmt/trainer.py b/onmt/trainer.py index 5ab396b3..e3bf7d0c 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -481,6 +481,9 @@ def _gradient_accumulation_over_lang_pairs( if src_lengths is not None: report_stats.n_src_words += src_lengths.sum().item() + # tgt_outer corresponds to the target-side input. The expected + # decoder output will be read directly from the batch: + # cf. 
`onmt.utils.loss.CommonLossCompute._make_shard_state` tgt_outer = batch.tgt bptt = False diff --git a/onmt/transforms/denoising.py b/onmt/transforms/denoising.py index aa46b062..36fa8c44 100644 --- a/onmt/transforms/denoising.py +++ b/onmt/transforms/denoising.py @@ -322,9 +322,11 @@ def __repr__(self): return '{}({})'.format(cls_name, cls_args) -@register_transform(name='ae_noise') +@register_transform(name='denoising') class NoiseTransform(Transform): def __init__(self, opts): + if opts.random_ratio > 0 and opts.denoising_objective == 'mass': + raise ValueError('Random word replacement is incompatible with MASS') super().__init__(opts) self.denoising_objective = opts.denoising_objective @@ -362,7 +364,7 @@ def add_options(cls, parser): "-random_ratio", type=float, default=0.0, - help="Instead of using {}, use random token this often.".format(DefaultTokens.MASK), + help=f"Instead of using {DefaultTokens.MASK}, use random token this often. Incompatible with MASS", ) group.add( diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index 1b2f2701..b05892aa 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -329,6 +329,7 @@ def _make_shard_state(self, batch, output, range_, attns=None): range_end = range_[1] shard_state = { "output": output, + # TODO: target here is likely unnecessary, as it now corresponds to target-side input "target": batch.tgt[range_start:range_end, :, 0], "labels": batch.labels[range_start:range_end, :, 0], } diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index b29729a1..e979e338 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -168,7 +168,7 @@ def _get_all_transform(cls, opt): if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0: if not all_transforms.isdisjoint({'sentencepiece', 'bpe', 'onmt_tokenize'}): raise ValueError('lambda_align is not compatible with on-the-fly tokenization.') - if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'ae_noise'}): + if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'denoising'}): raise ValueError('lambda_align is not compatible yet with potential token deletion/addition.') opt._all_transform = all_transforms From d13601e7f99c45d33e1cdd311a7dded222404563 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Mon, 25 Sep 2023 12:18:50 +0300 Subject: [PATCH 6/7] lint --- onmt/tests/test_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py index fc17ca7f..32e25f83 100644 --- a/onmt/tests/test_transform.py +++ b/onmt/tests/test_transform.py @@ -520,7 +520,7 @@ def test_vocab_required_transform(self): transforms_cls = get_transforms_cls(["denoising"]) opt = Namespace(random_ratio=1, denoising_objective='mass') with self.assertRaises(ValueError): - transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None) + make_transforms(opt, transforms_cls, vocabs=None, task=None) class TestFeaturesTransform(unittest.TestCase): From 43d24607d0b147c7339c05ee403914ecd47b713b Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Mon, 25 Sep 2023 12:34:53 +0300 Subject: [PATCH 7/7] fixing names --- onmt/utils/loss.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index b05892aa..100e7eb6 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -179,19 +179,19 @@ def __call__(self, batch, output, attns, normalization=1.0, shard_size=0, trunc_ batch_stats.update(stats) return None, batch_stats - def _stats(self, loss, scores, target): + 
def _stats(self, loss, scores, labels): """ Args: loss (:obj:`FloatTensor`): the loss computed by the loss criterion. scores (:obj:`FloatTensor`): a score for each possible output - target (:obj:`FloatTensor`): true targets + labels (:obj:`FloatTensor`): true targets Returns: :obj:`onmt.utils.Statistics` : statistics for this batch. """ pred = scores.max(1)[1] - non_padding = target.ne(self.padding_idx) - num_correct = pred.eq(target).masked_select(non_padding).sum().item() + non_padding = labels.ne(self.padding_idx) + num_correct = pred.eq(labels).masked_select(non_padding).sum().item() num_non_padding = non_padding.sum().item() return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct) @@ -221,14 +221,14 @@ def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100): self.confidence = 1.0 - label_smoothing - def forward(self, output, target): + def forward(self, output, labels): """ output (FloatTensor): batch_size x n_classes - target (LongTensor): batch_size + labels (LongTensor): batch_size """ - model_prob = self.one_hot.repeat(target.size(0), 1) - model_prob.scatter_(1, target.unsqueeze(1), self.confidence) - model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) + model_prob = self.one_hot.repeat(labels.size(0), 1) + model_prob.scatter_(1, labels.unsqueeze(1), self.confidence) + model_prob.masked_fill_((labels == self.ignore_index).unsqueeze(1), 0) return F.kl_div(output, model_prob, reduction='sum')
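
Note (not part of the patch series above): the MASS-style complement masking that patch 2 introduces in `NoiseTransform.apply_mass` is easiest to see on a toy example. The sketch below mirrors that logic with stand-in names: `toy_mask` stands in for `BARTNoising.apply`, and `<mask>`/`<blank>` stand in for `DefaultTokens.MASK`/`DefaultTokens.PAD`. It is a simplified, length-preserving illustration under those assumptions, not code from the patches.

import random

MASK = '<mask>'   # stand-in for DefaultTokens.MASK
PAD = '<blank>'   # stand-in for DefaultTokens.PAD


def toy_mask(tokens, mask_ratio=0.5, seed=1):
    """Stand-in for BARTNoising.apply: mask a random subset of source tokens."""
    rng = random.Random(seed)
    return [MASK if rng.random() < mask_ratio else tok for tok in tokens]


def mass_example(src_tokens):
    """Build the (src, tgt, labels) triple the same way apply_mass does."""
    masked = toy_mask(src_tokens)
    # Complement mask: positions left visible in src are masked on the
    # target side, while positions masked in src are revealed as decoder input.
    complement_masked = [
        MASK if masked_item != MASK else source_item
        for masked_item, source_item in zip(masked, src_tokens)
    ]
    # Labels: only the originally masked positions carry a real token; the
    # rest is padding, which the loss criterion ignores via its padding index.
    labels = [
        PAD if cmasked_item == MASK else cmasked_item
        for cmasked_item in complement_masked
    ]
    return {'src': masked, 'tgt': complement_masked, 'labels': labels}


if __name__ == '__main__':
    for key, value in mass_example(['the', 'cat', 'sat', 'on', 'the', 'mat']).items():
        print(key, value)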