MASS objective + transform #16
@@ -322,10 +322,13 @@ def __repr__(self):
         return '{}({})'.format(cls_name, cls_args)
 
 
-@register_transform(name='bart')
-class BARTNoiseTransform(Transform):
+@register_transform(name='denoising')
+class NoiseTransform(Transform):
     def __init__(self, opts):
+        if opts.random_ratio > 0 and opts.denoising_objective == 'mass':
+            raise ValueError('Random word replacement is incompatible with MASS')
         super().__init__(opts)
+        self.denoising_objective = opts.denoising_objective
Review comment:
We should validate that this combination is not used. In BART, it is possible to occasionally substitute a random token instead of the mask token. This slightly alleviates the tendency to copy all unmasked tokens verbatim. In MASS, substituting a random token on the source side would also affect the target sequence: the random token would be complemented into a mask token, and would not contribute to the loss. This does not make sense.

Reply:
so
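To make the reviewer's point concrete, here is a toy trace of the complement masking. The token strings are placeholders rather than the actual DefaultTokens values, and the masking is written by hand instead of calling bart_noise:

```python
# A random replacement ('bridge') on the source side ends up as PAD in the MASS
# labels, so the corrupted token is never scored by the loss.
MASK, PAD = '<MASK>', '<PAD>'  # placeholders for DefaultTokens.MASK / DefaultTokens.PAD

src    = ['the', 'cat', 'sat', 'down']
masked = ['the', MASK, 'bridge', 'down']  # 'cat' was masked, 'sat' was replaced by a random token

# complement masking, mirroring apply_mass in this PR
complement_masked = [MASK if m != MASK else s for m, s in zip(masked, src)]
labels = [PAD if c == MASK else c for c in complement_masked]

print(complement_masked)  # ['<MASK>', 'cat', '<MASK>', '<MASK>']
print(labels)             # ['<PAD>', 'cat', '<PAD>', '<PAD>']  -- 'sat' is never predicted
```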
 
     @classmethod
     def get_specials(cls, opts):
@@ -338,8 +341,8 @@ def _set_seed(self, seed):
 
     @classmethod
     def add_options(cls, parser):
-        """Avalilable options relate to BART."""
-        group = parser.add_argument_group("Transform/BART")
+        """Avalilable options relate to BART/MASS denoising."""
+        group = parser.add_argument_group("Transform/Denoising AE")
         group.add(
             "--permute_sent_ratio",
             "-permute_sent_ratio",
@@ -361,7 +364,7 @@ def add_options(cls, parser):
             "-random_ratio",
             type=float,
             default=0.0,
-            help="Instead of using {}, use random token this often.".format(DefaultTokens.MASK),
+            help=f"Instead of using {DefaultTokens.MASK}, use random token this often. Incompatible with MASS",
         )
 
         group.add(
@@ -394,6 +397,13 @@ def add_options(cls, parser):
             choices=[-1, 0, 1],
             help="When masking N tokens, replace with 0, 1, or N tokens. (use -1 for N)",
         )
+        group.add(
+            "--denoising_objective",
+            type=str,
+            default='bart',
+            choices=['bart', 'mass'],
+            help='choose between BART-style or MASS-style denoising objectives'
+        )
 
     @classmethod
     def require_vocab(cls):
@@ -424,13 +434,42 @@ def warm_up(self, vocabs):
             is_joiner=is_joiner,
         )
 
-    def apply(self, example, is_train=False, stats=None, **kwargs):
+    def apply_bart(self, example, is_train=False, stats=None, **kwargs):
         """Apply BART noise to src side tokens."""
         if is_train:
             src = self.bart_noise.apply(example['src'])
             example['src'] = src
         return example
 
+    def apply_mass(self, example, is_train=False, stats=None, **kwargs):
+        """Apply BART noise to src side tokens, then complete as MASS scheme."""
+        if is_train:
+            masked = self.bart_noise.apply(example['src'])
+            complement_masked = [
+                DefaultTokens.MASK if masked_item != DefaultTokens.MASK
+                else source_item
+                for masked_item, source_item in zip(masked, example['src'])
+            ]
+            labels = [
+                DefaultTokens.PAD if cmasked_item == DefaultTokens.MASK
+                else cmasked_item
+                for cmasked_item in complement_masked
+            ]
+            example = {
+                'src': masked,
+                'tgt': complement_masked,
+                'labels': labels,
+            }
+        return example
+
+    def apply(self, example, is_train=False, stats=None, **kwargs):
+        if self.denoising_objective == 'bart':
+            return self.apply_bart(example, is_train=False, stats=None, **kwargs)
+        elif self.denoising_objective == 'mass':
+            return self.apply_mass(example, is_train=False, stats=None, **kwargs)
+        else:
+            raise NotImplementedError('Unknown denoising objective.')
+
     def _repr_args(self):
         """Return str represent key arguments for BART."""
         return repr(self.bart_noise)
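For reference, a minimal sketch of the example dict that the new apply_mass builds. The special tokens are placeholder strings and the masking is hand-made in place of self.bart_noise.apply:

```python
# Sketch of the MASS scheme implemented by apply_mass (placeholder tokens assumed).
MASK, PAD = '<MASK>', '<PAD>'

src    = ['we', 'will', 'meet', 'again', 'soon']
masked = ['we', MASK, MASK, 'again', 'soon']   # pretend bart_noise masked 'will meet'

tgt    = [MASK if m != MASK else s for m, s in zip(masked, src)]
labels = [PAD if t == MASK else t for t in tgt]

example = {'src': masked, 'tgt': tgt, 'labels': labels}
# src    -> encoder input:  ['we', '<MASK>', '<MASK>', 'again', 'soon']
# tgt    -> decoder input:  ['<MASK>', 'will', 'meet', '<MASK>', '<MASK>']
# labels -> loss targets:   ['<PAD>', 'will', 'meet', '<PAD>', '<PAD>']
# Only the masked positions ('will', 'meet') are scored; PAD positions are ignored.
```

The 'bart' branch, by contrast, only corrupts src and leaves the rest of the example untouched.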
@@ -179,19 +179,19 @@ def __call__(self, batch, output, attns, normalization=1.0, shard_size=0, trunc_
             batch_stats.update(stats)
         return None, batch_stats
 
-    def _stats(self, loss, scores, target):
+    def _stats(self, loss, scores, labels):
         """
         Args:
             loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
             scores (:obj:`FloatTensor`): a score for each possible output
-            target (:obj:`FloatTensor`): true targets
+            labels (:obj:`FloatTensor`): true targets
 
         Returns:
             :obj:`onmt.utils.Statistics` : statistics for this batch.
         """
         pred = scores.max(1)[1]
-        non_padding = target.ne(self.padding_idx)
-        num_correct = pred.eq(target).masked_select(non_padding).sum().item()
+        non_padding = labels.ne(self.padding_idx)
+        num_correct = pred.eq(labels).masked_select(non_padding).sum().item()
         num_non_padding = non_padding.sum().item()
         return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct)
@@ -221,14 +221,14 @@ def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
 
         self.confidence = 1.0 - label_smoothing
 
-    def forward(self, output, target):
+    def forward(self, output, labels):
         """
         output (FloatTensor): batch_size x n_classes
-        target (LongTensor): batch_size
+        labels (LongTensor): batch_size
         """
-        model_prob = self.one_hot.repeat(target.size(0), 1)
-        model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
-        model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)
+        model_prob = self.one_hot.repeat(labels.size(0), 1)
+        model_prob.scatter_(1, labels.unsqueeze(1), self.confidence)
+        model_prob.masked_fill_((labels == self.ignore_index).unsqueeze(1), 0)
 
         return F.kl_div(output, model_prob, reduction='sum')
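Since only the forward of LabelSmoothingLoss is visible in this hunk, here is a self-contained sketch of the smoothing logic. The way self.one_hot is built in __init__ (smoothing mass spread over vocab_size - 2 tokens, zero at ignore_index) is assumed from the usual OpenNMT implementation and is not shown in the diff:

```python
import torch
import torch.nn.functional as F

label_smoothing, vocab_size, ignore_index = 0.1, 6, 0

# assumed __init__ logic: spread the smoothing mass over non-gold, non-padding tokens
smoothing_value = label_smoothing / (vocab_size - 2)
one_hot = torch.full((vocab_size,), smoothing_value)
one_hot[ignore_index] = 0
confidence = 1.0 - label_smoothing

labels = torch.tensor([3, 5, ignore_index])              # last position is padding
output = torch.log_softmax(torch.randn(3, vocab_size), dim=-1)

model_prob = one_hot.repeat(labels.size(0), 1)            # (batch, vocab) smoothed targets
model_prob.scatter_(1, labels.unsqueeze(1), confidence)   # put most mass on the gold label
model_prob.masked_fill_((labels == ignore_index).unsqueeze(1), 0)  # padded rows -> all zeros
loss = F.kl_div(output, model_prob, reduction='sum')      # zero rows add nothing to the sum
```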
@@ -262,12 +262,14 @@ def _add_coverage_shard_state(self, shard_state, attns):
         )
         shard_state.update({"std_attn": attns.get("std"), "coverage_attn": coverage})
 
-    def _compute_loss(self, batch, output, target, std_attn=None, coverage_attn=None, align_head=None, ref_align=None):
+    def _compute_loss(
+        self, batch, output, target, labels, std_attn=None, coverage_attn=None, align_head=None, ref_align=None
+    ):
 
         bottled_output = self._bottle(output)
 
         scores = self.generator(bottled_output)
-        gtruth = target.view(-1)
+        gtruth = labels.view(-1)
 
         loss = self.criterion(scores, gtruth)
         if self.lambda_coverage != 0.0:
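A quick shape sketch of how the renamed arguments line up, assuming the decoder output has shape (tgt_len, batch, hidden) and that _bottle flattens the first two dimensions. The generator and criterion here are stand-ins, not the actual onmt objects:

```python
import torch
import torch.nn as nn

tgt_len, batch_size, hidden, vocab = 5, 2, 8, 11
output = torch.randn(tgt_len, batch_size, hidden)        # decoder states
labels = torch.randint(1, vocab, (tgt_len, batch_size))  # true targets from batch.labels

bottled_output = output.view(-1, hidden)                 # (tgt_len * batch, hidden), like _bottle
generator = nn.Linear(hidden, vocab)                     # stand-in for self.generator
scores = torch.log_softmax(generator(bottled_output), dim=-1)  # (tgt_len * batch, vocab)
gtruth = labels.view(-1)                                 # (tgt_len * batch,) rows align with scores
loss = nn.NLLLoss(ignore_index=0, reduction='sum')(scores, gtruth)  # stand-in for self.criterion
```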
@@ -327,7 +329,9 @@ def _make_shard_state(self, batch, output, range_, attns=None):
         range_end = range_[1]
         shard_state = {
             "output": output,
+            # TODO: target here is likely unnecessary, as it now corresponds to target-side input
             "target": batch.tgt[range_start:range_end, :, 0],
+            "labels": batch.labels[range_start:range_end, :, 0],
Review comment:
It may not actually be necessary to have separate "target" and "labels" here. As this is the loss, it is only concerned with comparing the output to the desired labels. The decoder input ("target") shouldn't be needed, and as far as I can tell is not used. This way is clearer, though. Also, slicing a tensor and then not touching the slice should not incur much of a cost, as it will be just a view, so removing the "target" might not make any difference.

Reply:
I'm strongly in favor of the currently implemented approach here (I'd rather anything that exits the batchers to have roughly the same attributes).

Reply:
Yes, consistent naming and structure is desirable, and is a good enough motivation for keeping target here. On the topic of consistent naming: I noticed that there are some places (at least in the …
         }
         if self.lambda_coverage != 0.0:
             self._add_coverage_shard_state(shard_state, attns)
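The slicing remark in the comment above is easy to check: basic indexing such as batch.tgt[range_start:range_end, :, 0] returns a view that shares storage with the original tensor, so keeping an unused "target" entry in the shard state costs essentially nothing. A small sketch:

```python
import torch

batch_tgt = torch.randint(0, 100, (50, 32, 1))  # (tgt_len, batch, feats), stand-in for batch.tgt
target = batch_tgt[10:20, :, 0]                 # same kind of slice as in _make_shard_state

batch_tgt[10, 0, 0] = 7
assert target[0, 0] == 7   # the slice reflects the change: it is a view, not a copy
```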
Review comment:
Later, on line 484, tgt is used: `tgt_outer = batch.tgt`. If I read this correctly, the idea is that `tgt_outer` is used only as the input sequence to the decoder. The whole batch is then given to the loss function `self.train_loss_md`, which will take care of using the label sequence instead. This is not easy to figure out, because the loss function is wrapped in several layers of onmt scaffolding. A comment would be nice.

Reply:
Yes. Full disclosure: I did minimal edits. I don't know why the code doesn't use `batch.tgt` directly on line 490. I try not to fix what's not broken.
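The comment asked for above could be as simple as the following; the wording is a suggestion paraphrasing this discussion, not a quote from the actual trainer source:

```python
# Hypothetical wording for an inline comment around the lines discussed above:
#
#   tgt_outer = batch.tgt  # used only as the decoder input (teacher forcing);
#   # the loss function (self.train_loss_md) takes the whole batch and reads
#   # batch.labels itself for the true target sequence
```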