Merge pull request #16 from Helsinki-NLP/feats/mass
MASS objective + transform
TimotheeMickus authored Sep 25, 2023
2 parents e46eed2 + 43d2460 commit 470d466
Showing 9 changed files with 94 additions and 31 deletions.
4 changes: 2 additions & 2 deletions docs/source/config_config.md
@@ -34,7 +34,7 @@ The meta-parameters under the `config_config` key:
Path templates for source and target corpora, respectively.
The path templates can contain the following variables that will be substituted by `config_config`:

- Directional corpus mode
- Directional corpus mode
- `{src_lang}`: The source language of the task
- `{tgt_lang}`: The target language of the task
- `{lang_pair}`: `{src_lang}-{tgt_lang}` for convenience
@@ -99,7 +99,7 @@ Generate translation configs for zero-shot directions.
#### `transforms` and `ae_transforms`

A list of transforms, for translation tasks and autoencoder tasks, respectively.
Use this to apply subword segmentation, e.g. using `sentencepiece`, and `bart` noise for autoencoder.
Use this to apply subword segmentation, e.g. using `sentencepiece`, and `denoising` noise for autoencoder.
Both of these may change the sequence length, necessitating a `filtertoolong` transform.

#### `enc_sharing_groups` and `dec_sharing_groups`
6 changes: 3 additions & 3 deletions examples/config_config.yaml
@@ -16,7 +16,7 @@ config_config:
ae_transforms:
- sentencepiece
- filtertoolong
- bart
- denoising
enc_sharing_groups:
- GROUP
- FULL
@@ -27,7 +27,7 @@ config_config:
translation_config_dir: config/translation.opus
n_gpus_per_node: 4
n_nodes: 2

# Note that this specifies the groups manually instead of clustering
groups:
"en": "en"
@@ -43,7 +43,7 @@ config_config:

save_data: generated/opus.spm32k
# vocabs serve two purposes: defines the vocab files, and gives the potential languages to consider
src_vocab:
src_vocab:
"af": "/scratch/project_2005099/data/opus/prepare_opus_data_tc_out/opusTC.afr.32k.spm.vocab"
"da": "/scratch/project_2005099/data/opus/prepare_opus_data_tc_out/opusTC.dan.32k.spm.vocab"
"en": "/scratch/project_2005099/data/opus/prepare_opus_data_tc_out/opusTC.eng.32k.spm.vocab"
15 changes: 13 additions & 2 deletions onmt/inputters/dataset.py
@@ -18,12 +18,15 @@
class Batch():
src: tuple # of torch Tensors
tgt: torch.Tensor
labels: torch.Tensor
batch_size: int

def to(self, device):
self.src = (self.src[0].to(device), self.src[1].to(device))
if self.tgt is not None:
self.tgt = self.tgt.to(device)
if self.labels is not None:
self.labels = self.labels.to(device)
return self


@@ -156,8 +159,16 @@ def collate_fn(self, examples):
tgt_padidx = self.vocabs['tgt'][DefaultTokens.PAD]
src_lengths = torch.tensor([ex['src'].numel() for ex in examples], device='cpu')
src = (pad_sequence([ex['src'] for ex in examples], padding_value=src_padidx).unsqueeze(-1), src_lengths)
tgt = pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1) if has_tgt else None
batch = Batch(src, tgt, len(examples))
if has_tgt:
tgt = pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1)
if 'labels' not in examples[0].keys():
labels = tgt
else:
labels = pad_sequence([ex['labels'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1)
else:
tgt = None
labels = None
batch = Batch(src, tgt, labels, len(examples))
return batch


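
The fallback in `collate_fn` keeps the new field backwards compatible: when no transform has written a `labels` entry (every objective except MASS), the padded target tensor is simply reused as the loss target. A minimal sketch of the two branches, with toy tensors and an assumed pad index standing in for `vocabs['tgt'][DefaultTokens.PAD]`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD = 1  # hypothetical pad index, for illustration only

examples = [
    {'src': torch.tensor([5, 6, 7]), 'tgt': torch.tensor([8, 9])},
    {'src': torch.tensor([5, 6]), 'tgt': torch.tensor([8, 9, 10])},
]

tgt = pad_sequence([ex['tgt'] for ex in examples], padding_value=PAD).unsqueeze(-1)
if 'labels' not in examples[0].keys():
    # BART-style denoising and plain translation: loss targets are the decoder targets.
    labels = tgt
else:
    # MASS: the transform supplies separate loss targets.
    labels = pad_sequence([ex['labels'] for ex in examples], padding_value=PAD).unsqueeze(-1)
```
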
2 changes: 1 addition & 1 deletion onmt/tests/test_subword_marker.py
@@ -1,6 +1,6 @@
import unittest

from onmt.transforms.bart import word_start_finder
from onmt.transforms.denoising import word_start_finder
from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer
from onmt.constants import SubwordMarker

14 changes: 10 additions & 4 deletions onmt/tests/test_transform.py
@@ -11,7 +11,7 @@
make_transforms,
TransformPipe,
)
from onmt.transforms.bart import BARTNoising
from onmt.transforms.denoising import BARTNoising


class TestTransform(unittest.TestCase):
@@ -22,22 +22,22 @@ def test_transform_register(self):
"sentencepiece",
"bpe",
"onmt_tokenize",
"bart",
"denoising",
"switchout",
"tokendrop",
"tokenmask",
]
get_transforms_cls(builtin_transform)

def test_vocab_required_transform(self):
transforms_cls = get_transforms_cls(["bart", "switchout"])
transforms_cls = get_transforms_cls(["denoising", "switchout"])
opt = Namespace(seed=-1, switchout_temperature=1.0)
# transforms that require a vocab will not be created if no vocab is provided
transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None)
self.assertEqual(len(transforms), 0)
with self.assertRaises(ValueError):
transforms_cls["switchout"](opt).warm_up(vocabs=None)
transforms_cls["bart"](opt).warm_up(vocabs=None)
transforms_cls["denoising"](opt).warm_up(vocabs=None)

def test_transform_specials(self):
transforms_cls = get_transforms_cls(["prefix"])
@@ -516,6 +516,12 @@ def test_span_infilling(self):
# print(f"Text Span Infilling: {infillied} / {tokens}")
# print(n_words, n_masked)

def test_vocab_required_transform(self):
transforms_cls = get_transforms_cls(["denoising"])
opt = Namespace(random_ratio=1, denoising_objective='mass')
with self.assertRaises(ValueError):
make_transforms(opt, transforms_cls, vocabs=None, task=None)


class TestFeaturesTransform(unittest.TestCase):
def test_inferfeats(self):
5 changes: 4 additions & 1 deletion onmt/trainer.py
@@ -465,7 +465,7 @@ def _gradient_accumulation_over_lang_pairs(
seen_comm_batches.add(comm_batch)
if self.norm_method == "tokens":
num_tokens = (
batch.tgt[1:, :, 0].ne(self.train_loss_md[f'trainloss{metadata.tgt_lang}'].padding_idx).sum()
batch.labels[1:, :, 0].ne(self.train_loss_md[f'trainloss{metadata.tgt_lang}'].padding_idx).sum()
)
normalization += num_tokens.item()
else:
@@ -485,6 +485,9 @@
if src_lengths is not None:
report_stats.n_src_words += src_lengths.sum().item()

# tgt_outer corresponds to the target-side input. The expected
# decoder output will be read directly from the batch:
# cf. `onmt.utils.loss.CommonLossCompute._make_shard_state`
tgt_outer = batch.tgt

bptt = False
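
Counting tokens from `batch.labels` rather than `batch.tgt` matters once the MASS transform is in play: positions the decoder is not asked to predict are stored as padding in `labels`, so they no longer inflate the per-batch normalization. A hedged sketch of that count, with a toy tensor and an assumed padding index:

```python
import torch

padding_idx = 1  # hypothetical pad index
# labels has shape (tgt_len, batch, 1); here only one position past BOS is a real target
labels = torch.tensor([2, 1, 5, 1]).view(-1, 1, 1)
num_tokens = labels[1:, :, 0].ne(padding_idx).sum()
print(int(num_tokens))  # 1 -> only the revealed MASS position is counted
```
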
51 changes: 45 additions & 6 deletions onmt/transforms/bart.py → onmt/transforms/denoising.py
@@ -322,10 +322,13 @@ def __repr__(self):
return '{}({})'.format(cls_name, cls_args)


@register_transform(name='bart')
class BARTNoiseTransform(Transform):
@register_transform(name='denoising')
class NoiseTransform(Transform):
def __init__(self, opts):
if opts.random_ratio > 0 and opts.denoising_objective == 'mass':
raise ValueError('Random word replacement is incompatible with MASS')
super().__init__(opts)
self.denoising_objective = opts.denoising_objective

@classmethod
def get_specials(cls, opts):
@@ -338,8 +341,8 @@ def _set_seed(self, seed):

@classmethod
def add_options(cls, parser):
"""Avalilable options relate to BART."""
group = parser.add_argument_group("Transform/BART")
"""Avalilable options relate to BART/MASS denoising."""
group = parser.add_argument_group("Transform/Denoising AE")
group.add(
"--permute_sent_ratio",
"-permute_sent_ratio",
@@ -361,7 +364,7 @@ def add_options(cls, parser):
"-random_ratio",
type=float,
default=0.0,
help="Instead of using {}, use random token this often.".format(DefaultTokens.MASK),
help=f"Instead of using {DefaultTokens.MASK}, use random token this often. Incompatible with MASS",
)

group.add(
@@ -394,6 +397,13 @@ def add_options(cls, parser):
choices=[-1, 0, 1],
help="When masking N tokens, replace with 0, 1, or N tokens. (use -1 for N)",
)
group.add(
"--denoising_objective",
type=str,
default='bart',
choices=['bart', 'mass'],
help='choose between BART-style or MASS-style denoising objectives'
)

@classmethod
def require_vocab(cls):
@@ -424,13 +434,42 @@ def warm_up(self, vocabs):
is_joiner=is_joiner,
)

def apply(self, example, is_train=False, stats=None, **kwargs):
def apply_bart(self, example, is_train=False, stats=None, **kwargs):
"""Apply BART noise to src side tokens."""
if is_train:
src = self.bart_noise.apply(example['src'])
example['src'] = src
return example

def apply_mass(self, example, is_train=False, stats=None, **kwargs):
"""Apply BART noise to src side tokens, then complete as MASS scheme."""
if is_train:
masked = self.bart_noise.apply(example['src'])
complement_masked = [
DefaultTokens.MASK if masked_item != DefaultTokens.MASK
else source_item
for masked_item, source_item in zip(masked, example['src'])
]
labels = [
DefaultTokens.PAD if cmasked_item == DefaultTokens.MASK
else cmasked_item
for cmasked_item in complement_masked
]
example = {
'src': masked,
'tgt': complement_masked,
'labels': labels,
}
return example

def apply(self, example, is_train=False, stats=None, **kwargs):
if self.denoising_objective == 'bart':
return self.apply_bart(example, is_train=is_train, stats=stats, **kwargs)
elif self.denoising_objective == 'mass':
return self.apply_mass(example, is_train=is_train, stats=stats, **kwargs)
else:
raise NotImplementedError('Unknown denoising objective.')

def _repr_args(self):
"""Return str represent key arguments for BART."""
return repr(self.bart_noise)
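
To make the MASS bookkeeping above concrete, here is a toy walk-through of what `apply_mass` produces, with a hand-picked mask pattern and placeholder special tokens (the real mask positions are sampled by `BARTNoising`, and the actual token strings come from `DefaultTokens`):

```python
MASK, PAD = '<mask>', '<blank>'  # placeholders for DefaultTokens.MASK / DefaultTokens.PAD

src_tokens = ['the', 'cat', 'sat', 'on', 'the', 'mat']
masked = ['the', MASK, MASK, 'on', 'the', 'mat']             # encoder input ('src')

# Complement: reveal exactly the tokens that were masked on the source side.
complement = [MASK if m != MASK else s for m, s in zip(masked, src_tokens)]
# -> ['<mask>', 'cat', 'sat', '<mask>', '<mask>', '<mask>']   decoder input ('tgt')

# Labels: the loss only scores the revealed positions; everything else is padding.
labels = [PAD if c == MASK else c for c in complement]
# -> ['<blank>', 'cat', 'sat', '<blank>', '<blank>', '<blank>']
```
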
26 changes: 15 additions & 11 deletions onmt/utils/loss.py
@@ -179,19 +179,19 @@ def __call__(self, batch, output, attns, normalization=1.0, shard_size=0, trunc_
batch_stats.update(stats)
return None, batch_stats

def _stats(self, loss, scores, target):
def _stats(self, loss, scores, labels):
"""
Args:
loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
scores (:obj:`FloatTensor`): a score for each possible output
target (:obj:`FloatTensor`): true targets
labels (:obj:`FloatTensor`): true targets
Returns:
:obj:`onmt.utils.Statistics` : statistics for this batch.
"""
pred = scores.max(1)[1]
non_padding = target.ne(self.padding_idx)
num_correct = pred.eq(target).masked_select(non_padding).sum().item()
non_padding = labels.ne(self.padding_idx)
num_correct = pred.eq(labels).masked_select(non_padding).sum().item()
num_non_padding = non_padding.sum().item()
return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct)

@@ -221,14 +221,14 @@ def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):

self.confidence = 1.0 - label_smoothing

def forward(self, output, target):
def forward(self, output, labels):
"""
output (FloatTensor): batch_size x n_classes
target (LongTensor): batch_size
labels (LongTensor): batch_size
"""
model_prob = self.one_hot.repeat(target.size(0), 1)
model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)
model_prob = self.one_hot.repeat(labels.size(0), 1)
model_prob.scatter_(1, labels.unsqueeze(1), self.confidence)
model_prob.masked_fill_((labels == self.ignore_index).unsqueeze(1), 0)

return F.kl_div(output, model_prob, reduction='sum')

@@ -262,12 +262,14 @@ def _add_coverage_shard_state(self, shard_state, attns):
)
shard_state.update({"std_attn": attns.get("std"), "coverage_attn": coverage})

def _compute_loss(self, batch, output, target, std_attn=None, coverage_attn=None, align_head=None, ref_align=None):
def _compute_loss(
self, batch, output, target, labels, std_attn=None, coverage_attn=None, align_head=None, ref_align=None
):

bottled_output = self._bottle(output)

scores = self.generator(bottled_output)
gtruth = target.view(-1)
gtruth = labels.view(-1)

loss = self.criterion(scores, gtruth)
if self.lambda_coverage != 0.0:
@@ -327,7 +329,9 @@ def _make_shard_state(self, batch, output, range_, attns=None):
range_end = range_[1]
shard_state = {
"output": output,
# TODO: target here is likely unnecessary, as it now corresponds to target-side input
"target": batch.tgt[range_start:range_end, :, 0],
"labels": batch.labels[range_start:range_end, :, 0],
}
if self.lambda_coverage != 0.0:
self._add_coverage_shard_state(shard_state, attns)
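
The net effect in the loss is that `target` keeps describing the decoder's input-side tokens while `labels` is what actually gets flattened and scored; because the MASS transform writes the pad token at positions that should not be predicted, those positions are ignored by the padding-aware criterion. A hedged sketch with toy tensors and an assumed pad index (a stand-in for the criterion the trainer builds from the target vocab and options):

```python
import torch
import torch.nn.functional as F

padding_idx = 1  # hypothetical pad index
# a labels shard as stored by _make_shard_state: (tgt_len, batch)
labels = torch.tensor([[4, 1],
                       [1, 9]])
gtruth = labels.view(-1)                       # tensor([4, 1, 1, 9])
scores = torch.randn(4, 10).log_softmax(-1)    # stand-in generator output: (tokens, vocab)
loss = F.nll_loss(scores, gtruth, ignore_index=padding_idx, reduction='sum')
# only the two non-pad entries (4 and 9) contribute to the loss
```
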
2 changes: 1 addition & 1 deletion onmt/utils/parse.py
@@ -168,7 +168,7 @@ def _get_all_transform(cls, opt):
if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0:
if not all_transforms.isdisjoint({'sentencepiece', 'bpe', 'onmt_tokenize'}):
raise ValueError('lambda_align is not compatible with on-the-fly tokenization.')
if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'bart'}):
if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'denoising'}):
raise ValueError('lambda_align is not compatible yet with potential token deletion/addition.')
opt._all_transform = all_transforms
