MASS objective + transform #16

Merged: 7 commits, Sep 25, 2023
4 changes: 2 additions & 2 deletions docs/source/config_config.md
@@ -34,7 +34,7 @@ The meta-parameters under the `config_config` key:
Path templates for source and target corpora, respectively.
The path templates can contain the following variables that will be substituted by `config_config`:

- Directional corpus mode
- `{src_lang}`: The source language of the task
- `{tgt_lang}`: The target language of the task
- `{lang_pair}`: `{src_lang}-{tgt_lang}` for convenience
@@ -99,7 +99,7 @@ Generate translation configs for zero-shot directions.
#### `transforms` and `ae_transforms`

A list of transforms, for translation tasks and autoencoder tasks, respectively.
Use this to apply subword segmentation, e.g. using `sentencepiece`, and `bart` noise for autoencoder.
Use this to apply subword segmentation, e.g. using `sentencepiece`, and `denoising` noise for autoencoder.
Both of these may change the sequence length, necessitating a `filtertoolong` transform.
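For illustration, a minimal sketch of how these two keys might look in practice (the transform names follow this PR; the surrounding structure mirrors `examples/config_config.yaml` below, and omitted keys are left out on purpose):

```yaml
config_config:
  transforms:        # applied to translation tasks
    - sentencepiece
    - filtertoolong
  ae_transforms:     # applied to autoencoder tasks
    - sentencepiece
    - filtertoolong
    - denoising      # BART- or MASS-style noise; see the new denoising_objective option
```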

#### `enc_sharing_groups` and `dec_sharing_groups`
6 changes: 3 additions & 3 deletions examples/config_config.yaml
@@ -16,7 +16,7 @@ config_config:
ae_transforms:
- sentencepiece
- filtertoolong
- bart
- denoising
enc_sharing_groups:
- GROUP
- FULL
@@ -27,7 +27,7 @@ config_config:
translation_config_dir: config/translation.opus
n_gpus_per_node: 4
n_nodes: 2

# Note that this specifies the groups manually instead of clustering
groups:
"en": "en"
@@ -43,7 +43,7 @@ config_config:

save_data: generated/opus.spm32k
# vocabs serve two purposes: defines the vocab files, and gives the potential languages to consider
src_vocab:
"af": "/scratch/project_2005099/data/opus/prepare_opus_data_tc_out/opusTC.afr.32k.spm.vocab"
"da": "/scratch/project_2005099/data/opus/prepare_opus_data_tc_out/opusTC.dan.32k.spm.vocab"
"en": "/scratch/project_2005099/data/opus/prepare_opus_data_tc_out/opusTC.eng.32k.spm.vocab"
15 changes: 13 additions & 2 deletions onmt/inputters/dataset.py
@@ -18,12 +18,15 @@
class Batch():
src: tuple # of torch Tensors
tgt: torch.Tensor
labels: torch.Tensor
batch_size: int

def to(self, device):
self.src = (self.src[0].to(device), self.src[1].to(device))
if self.tgt is not None:
self.tgt = self.tgt.to(device)
if self.labels is not None:
self.labels = self.labels.to(device)
return self


@@ -156,8 +159,16 @@ def collate_fn(self, examples):
tgt_padidx = self.vocabs['tgt'][DefaultTokens.PAD]
src_lengths = torch.tensor([ex['src'].numel() for ex in examples], device='cpu')
src = (pad_sequence([ex['src'] for ex in examples], padding_value=src_padidx).unsqueeze(-1), src_lengths)
tgt = pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1) if has_tgt else None
batch = Batch(src, tgt, len(examples))
if has_tgt:
tgt = pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1)
if 'labels' not in examples[0].keys():
labels = tgt
else:
labels = pad_sequence([ex['labels'] for ex in examples], padding_value=tgt_padidx).unsqueeze(-1)
else:
tgt = None
labels = None
batch = Batch(src, tgt, labels, len(examples))
return batch
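To make the new `labels` handling concrete, here is a small sketch of the two kinds of examples this `collate_fn` can now receive (the token IDs and the pad/mask indices are hypothetical):

```python
import torch

PAD, MASK = 1, 4  # hypothetical indices

# A plain translation example has no 'labels' key: collate_fn falls back to labels = tgt.
translation_ex = {'src': torch.tensor([5, 6, 7]), 'tgt': torch.tensor([8, 9, 10])}

# A MASS example carries a separate 'labels' sequence: PAD marks positions
# that the loss should ignore.
mass_ex = {
    'src': torch.tensor([5, MASK, MASK]),   # masked encoder input
    'tgt': torch.tensor([MASK, 6, 7]),      # complement-masked decoder input
    'labels': torch.tensor([PAD, 6, 7]),
}
```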


2 changes: 1 addition & 1 deletion onmt/tests/test_subword_marker.py
@@ -1,6 +1,6 @@
import unittest

from onmt.transforms.bart import word_start_finder
from onmt.transforms.denoising import word_start_finder
from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer
from onmt.constants import SubwordMarker

14 changes: 10 additions & 4 deletions onmt/tests/test_transform.py
@@ -11,7 +11,7 @@
make_transforms,
TransformPipe,
)
from onmt.transforms.bart import BARTNoising
from onmt.transforms.denoising import BARTNoising


class TestTransform(unittest.TestCase):
@@ -22,22 +22,22 @@ def test_transform_register(self):
"sentencepiece",
"bpe",
"onmt_tokenize",
"bart",
"denoising",
"switchout",
"tokendrop",
"tokenmask",
]
get_transforms_cls(builtin_transform)

def test_vocab_required_transform(self):
transforms_cls = get_transforms_cls(["bart", "switchout"])
transforms_cls = get_transforms_cls(["denoising", "switchout"])
opt = Namespace(seed=-1, switchout_temperature=1.0)
# transforms that require a vocab will not be created if no vocab is provided
transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None)
self.assertEqual(len(transforms), 0)
with self.assertRaises(ValueError):
transforms_cls["switchout"](opt).warm_up(vocabs=None)
transforms_cls["bart"](opt).warm_up(vocabs=None)
transforms_cls["denoising"](opt).warm_up(vocabs=None)

def test_transform_specials(self):
transforms_cls = get_transforms_cls(["prefix"])
@@ -516,6 +516,12 @@ def test_span_infilling(self):
# print(f"Text Span Infilling: {infillied} / {tokens}")
# print(n_words, n_masked)

def test_vocab_required_transform(self):
transforms_cls = get_transforms_cls(["denoising"])
opt = Namespace(random_ratio=1, denoising_objective='mass')
with self.assertRaises(ValueError):
make_transforms(opt, transforms_cls, vocabs=None, task=None)


class TestFeaturesTransform(unittest.TestCase):
def test_inferfeats(self):
5 changes: 4 additions & 1 deletion onmt/trainer.py
@@ -461,7 +461,7 @@ def _gradient_accumulation_over_lang_pairs(
seen_comm_batches.add(comm_batch)
if self.norm_method == "tokens":
num_tokens = (
batch.tgt[1:, :, 0].ne(self.train_loss_md[f'trainloss{metadata.tgt_lang}'].padding_idx).sum()
batch.labels[1:, :, 0].ne(self.train_loss_md[f'trainloss{metadata.tgt_lang}'].padding_idx).sum()
)
normalization += num_tokens.item()
else:
Collaborator:

Later on line 484 tgt is used: tgt_outer = batch.tgt.

If I read this correctly, the idea is that tgt_outer is used only as the input sequence to the decoder. The whole batch is then given to the loss function self.train_loss_md which will take care of using the label sequence instead.

This is not easy to figure out, because the loss function is wrapped in several layers of onmt scaffolding. A comment would be nice.

Collaborator (Author):

Yes.

Full disclosure: I did minimal edits. I don't know why the code doesn't use batch.tgt directly on line 490. I try not to fix what's not broken.

@@ -481,6 +481,9 @@ def _gradient_accumulation_over_lang_pairs(
if src_lengths is not None:
report_stats.n_src_words += src_lengths.sum().item()

# tgt_outer corresponds to the target-side input. The expected
# decoder output will be read directly from the batch:
# cf. `onmt.utils.loss.CommonLossCompute._make_shard_state`
tgt_outer = batch.tgt

bptt = False
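A self-contained sketch of that division of labour, under the usual onmt convention that target-side tensors are shaped `(seq_len, batch_size, 1)` (the tensors and padding index below are made up, and the real call chain is wrapped in onmt scaffolding):

```python
import torch

PAD = 1
tgt = torch.tensor([[2], [8], [9], [3]]).unsqueeze(-1)       # <bos> w1 w2 <eos>
labels = torch.tensor([[2], [PAD], [9], [3]]).unsqueeze(-1)  # one position masked out of the loss

decoder_input = tgt[:-1]           # teacher forcing: target-side input to the decoder
gold = labels[1:, :, 0]            # what the criterion is evaluated against
num_tokens = gold.ne(PAD).sum()    # mirrors the normalization count in the hunk above
```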
51 changes: 45 additions & 6 deletions onmt/transforms/bart.py → onmt/transforms/denoising.py
@@ -322,10 +322,13 @@ def __repr__(self):
return '{}({})'.format(cls_name, cls_args)


@register_transform(name='bart')
class BARTNoiseTransform(Transform):
@register_transform(name='denoising')
class NoiseTransform(Transform):
def __init__(self, opts):
if opts.random_ratio > 0 and opts.denoising_objective == 'mass':
raise ValueError('Random word replacement is incompatible with MASS')
super().__init__(opts)
self.denoising_objective = opts.denoising_objective
Collaborator:

We should validate that mask-random is not used with the MASS objective.

In BART, it is possible to occasionally substitute a random token instead of the mask token. This slightly alleviates the tendency to copy all unmasked tokens verbatim.

In MASS, substituting a random token on the source side would also affect the target sequence: the random token would be complemented into a mask token, and would not contribute to the loss. This does not make sense.

Collaborator (Author):

So the condition to check is `random_ratio > 0` (now validated in `__init__`).


@classmethod
def get_specials(cls, opts):
@@ -338,8 +341,8 @@ def _set_seed(self, seed):

@classmethod
def add_options(cls, parser):
"""Avalilable options relate to BART."""
group = parser.add_argument_group("Transform/BART")
"""Avalilable options relate to BART/MASS denoising."""
group = parser.add_argument_group("Transform/Denoising AE")
group.add(
"--permute_sent_ratio",
"-permute_sent_ratio",
Expand All @@ -361,7 +364,7 @@ def add_options(cls, parser):
"-random_ratio",
type=float,
default=0.0,
help="Instead of using {}, use random token this often.".format(DefaultTokens.MASK),
help=f"Instead of using {DefaultTokens.MASK}, use random token this often. Incompatible with MASS",
)

group.add(
@@ -394,6 +397,13 @@ def add_options(cls, parser):
choices=[-1, 0, 1],
help="When masking N tokens, replace with 0, 1, or N tokens. (use -1 for N)",
)
group.add(
"--denoising_objective",
type=str,
default='bart',
choices=['bart', 'mass'],
help='choose between BART-style and MASS-style denoising objectives'
)

@classmethod
def require_vocab(cls):
@@ -424,13 +434,42 @@ def warm_up(self, vocabs):
is_joiner=is_joiner,
)

def apply(self, example, is_train=False, stats=None, **kwargs):
def apply_bart(self, example, is_train=False, stats=None, **kwargs):
"""Apply BART noise to src side tokens."""
if is_train:
src = self.bart_noise.apply(example['src'])
example['src'] = src
return example

def apply_mass(self, example, is_train=False, stats=None, **kwargs):
"""Apply BART noise to src side tokens, then complete as MASS scheme."""
if is_train:
masked = self.bart_noise.apply(example['src'])
complement_masked = [
DefaultTokens.MASK if masked_item != DefaultTokens.MASK
else source_item
for masked_item, source_item in zip(masked, example['src'])
]
labels = [
DefaultTokens.PAD if cmasked_item == DefaultTokens.MASK
else cmasked_item
for cmasked_item in complement_masked
]
example = {
'src': masked,
'tgt': complement_masked,
'labels': labels,
}
return example

def apply(self, example, is_train=False, stats=None, **kwargs):
if self.denoising_objective == 'bart':
return self.apply_bart(example, is_train=is_train, stats=stats, **kwargs)
elif self.denoising_objective == 'mass':
return self.apply_mass(example, is_train=is_train, stats=stats, **kwargs)
else:
raise NotImplementedError('Unknown denoising objective.')

def _repr_args(self):
"""Return str represent key arguments for BART."""
return repr(self.bart_noise)
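A standalone illustration of the complement-masking in `apply_mass` above, using plain strings ('<mask>' and '<blank>' stand in for `DefaultTokens.MASK` and `DefaultTokens.PAD`; the sentence and the masked span are made up):

```python
MASK, PAD = '<mask>', '<blank>'

src = ['the', 'cat', 'sat', 'on', 'the', 'mat']
masked = ['the', 'cat', MASK, MASK, 'the', 'mat']   # pretend output of the BART noiser

# Decoder input: the complement of the source-side masking.
complement_masked = [MASK if m != MASK else s for m, s in zip(masked, src)]
assert complement_masked == [MASK, MASK, 'sat', 'on', MASK, MASK]

# Loss labels: only the originally masked span is predicted; everything else
# becomes PAD so the padding-aware criterion ignores it.
labels = [PAD if c == MASK else c for c in complement_masked]
assert labels == [PAD, PAD, 'sat', 'on', PAD, PAD]
```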
26 changes: 15 additions & 11 deletions onmt/utils/loss.py
@@ -179,19 +179,19 @@ def __call__(self, batch, output, attns, normalization=1.0, shard_size=0, trunc_
batch_stats.update(stats)
return None, batch_stats

def _stats(self, loss, scores, target):
def _stats(self, loss, scores, labels):
"""
Args:
loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
scores (:obj:`FloatTensor`): a score for each possible output
target (:obj:`FloatTensor`): true targets
labels (:obj:`FloatTensor`): true targets

Returns:
:obj:`onmt.utils.Statistics` : statistics for this batch.
"""
pred = scores.max(1)[1]
non_padding = target.ne(self.padding_idx)
num_correct = pred.eq(target).masked_select(non_padding).sum().item()
non_padding = labels.ne(self.padding_idx)
num_correct = pred.eq(labels).masked_select(non_padding).sum().item()
num_non_padding = non_padding.sum().item()
return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct)

@@ -221,14 +221,14 @@ def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):

self.confidence = 1.0 - label_smoothing

def forward(self, output, target):
def forward(self, output, labels):
"""
output (FloatTensor): batch_size x n_classes
target (LongTensor): batch_size
labels (LongTensor): batch_size
"""
model_prob = self.one_hot.repeat(target.size(0), 1)
model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)
model_prob = self.one_hot.repeat(labels.size(0), 1)
model_prob.scatter_(1, labels.unsqueeze(1), self.confidence)
model_prob.masked_fill_((labels == self.ignore_index).unsqueeze(1), 0)

return F.kl_div(output, model_prob, reduction='sum')
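A toy numeric sketch (vocabulary of 5, indices chosen arbitrarily) of why labels equal to `ignore_index` contribute nothing to this smoothed loss, which is what lets MASS mark non-predicted positions with PAD:

```python
import torch
import torch.nn.functional as F

vocab_size, padding_idx, smoothing = 5, 0, 0.1
confidence = 1.0 - smoothing
one_hot = torch.full((vocab_size,), smoothing / (vocab_size - 2))
one_hot[padding_idx] = 0

labels = torch.tensor([3, padding_idx])  # the second position is a PAD label
output = torch.log_softmax(torch.randn(2, vocab_size), dim=-1)

model_prob = one_hot.repeat(labels.size(0), 1)
model_prob.scatter_(1, labels.unsqueeze(1), confidence)
model_prob.masked_fill_((labels == padding_idx).unsqueeze(1), 0)  # PAD row becomes all zeros

loss = F.kl_div(output, model_prob, reduction='sum')  # the all-zero row adds exactly 0
```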

@@ -262,12 +262,14 @@ def _add_coverage_shard_state(self, shard_state, attns):
)
shard_state.update({"std_attn": attns.get("std"), "coverage_attn": coverage})

def _compute_loss(self, batch, output, target, std_attn=None, coverage_attn=None, align_head=None, ref_align=None):
def _compute_loss(
self, batch, output, target, labels, std_attn=None, coverage_attn=None, align_head=None, ref_align=None
):

bottled_output = self._bottle(output)

scores = self.generator(bottled_output)
gtruth = target.view(-1)
gtruth = labels.view(-1)

loss = self.criterion(scores, gtruth)
if self.lambda_coverage != 0.0:
@@ -327,7 +329,9 @@ def _make_shard_state(self, batch, output, range_, attns=None):
range_end = range_[1]
shard_state = {
"output": output,
# TODO: target here is likely unnecessary, as it now corresponds to target-side input
"target": batch.tgt[range_start:range_end, :, 0],
"labels": batch.labels[range_start:range_end, :, 0],
Collaborator:

It may not actually be necessary to have separate "target" and "labels" here. As this is the loss, it is only concerned with comparing the output to the desired labels. The decoder input ("target") shouldn't be needed, and as far as I can tell is not used.

This way is clearer, though. Also, slicing a tensor and then not touching the slice should not incur much of a cost, as it will be just a view, so removing the "target" might not make any difference.

Collaborator (Author):

I'm strongly in favor of the currently implemented approach here (I'd rather that anything that exits the batchers have roughly the same attributes).

Collaborator:

Yes, consistent naming and structure is desirable, and is a good enough motivation for keeping target here.

On the topic of consistent naming: I noticed that there are some places (at least in the _stats method) where the name target is used, but it now actually refers to labels. _stats takes gtruth, which was changed in this PR.

}
if self.lambda_coverage != 0.0:
self._add_coverage_shard_state(shard_state, attns)
2 changes: 1 addition & 1 deletion onmt/utils/parse.py
@@ -168,7 +168,7 @@ def _get_all_transform(cls, opt):
if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0:
if not all_transforms.isdisjoint({'sentencepiece', 'bpe', 'onmt_tokenize'}):
raise ValueError('lambda_align is not compatible with on-the-fly tokenization.')
if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'bart'}):
if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'denoising'}):
raise ValueError('lambda_align is not compatible yet with potential token deletion/addition.')
opt._all_transform = all_transforms
