diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py
index 2f6187d9ed..05797d36c3 100644
--- a/onmt/decoders/transformer.py
+++ b/onmt/decoders/transformer.py
@@ -186,7 +186,8 @@ def _forward(self, *args, **kwargs):
 
     def _compute_dec_mask(self, tgt_pad_mask, future):
         tgt_len = tgt_pad_mask.size(-1)
-        if not future:  # apply future_mask, result mask in (B, T, T)
+        if not future:
+            # Add triangular future_mask and pad_mask, result mask in (B, T, T).
             future_mask = torch.ones(
                 [tgt_len, tgt_len],
                 device=tgt_pad_mask.device,
@@ -197,9 +198,14 @@ def _compute_dec_mask(self, tgt_pad_mask, future):
                 future_mask = future_mask.triu_(-self.sliding_window)
             future_mask = future_mask.bool()
             future_mask = ~future_mask.view(1, tgt_len, tgt_len)
-
+            # Patch for scaled dot product attention.
+            patch_mask = ~torch.all(
+                tgt_pad_mask + future_mask, dim=2, keepdim=True
+            ).expand_as(tgt_pad_mask + future_mask)
             dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
-        else:  # only mask padding, result mask in (B, 1, T)
+            dec_mask = torch.logical_and(dec_mask, patch_mask)
+        else:
+            # Only mask padding, result mask in (B, 1, T).
             dec_mask = tgt_pad_mask
         return dec_mask
 
@@ -717,7 +723,9 @@ def _forward(
         dec_mask = None
 
         if layer_in.size(1) > 1:
-            # masking is necessary when sequence length is greater than one
+            # Masking is necessary when sequence length is greater than one
+            # The decoding has not started yet,
+            # we compute the scores on the source tokens in one shot.
             dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
             dec_mask = dec_mask.unsqueeze(1)
             dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1)
@@ -859,8 +867,11 @@ def detach_state(self):
     def forward(self, tgt, enc_out=None, step=None, **kwargs):
         """Decode, possibly stepwise."""
         if step == 0:
+            # decoding mode.
+            # Initialize KV cache.
             self._init_cache(tgt)
         elif step is None:
+            # training mode.
             for layer in self.transformer_layers:
                 layer.self_attn.layer_cache = (
                     False,
diff --git a/onmt/inputters/dynamic_iterator.py b/onmt/inputters/dynamic_iterator.py
index 23a9a5880b..8289d1a3ce 100644
--- a/onmt/inputters/dynamic_iterator.py
+++ b/onmt/inputters/dynamic_iterator.py
@@ -1,7 +1,7 @@
 """Module that contain iterator used for dynamic data."""
 import torch
 from itertools import cycle
-from onmt.constants import CorpusTask
+from onmt.constants import CorpusTask, ModelTask
 from onmt.inputters.text_corpus import get_corpora, build_corpora_iters
 from onmt.inputters.text_utils import (
     text_sort_key,
@@ -164,6 +164,10 @@ def __init__(
         self.skip_empty_level = skip_empty_level
         self.random_shuffler = RandomShuffler()
         self.bucket_idx = 0
+        if task != CorpusTask.TRAIN and vocabs["data_task"] == ModelTask.LANGUAGE_MODEL:
+            self.left_pad = True
+        else:
+            self.left_pad = False
 
     @classmethod
     def from_opt(
@@ -354,7 +358,9 @@ def __iter__(self):
                 # within the batch
                 if self.task == CorpusTask.TRAIN:
                     minibatch.sort(key=lambda x: self.sort_key(x[0]), reverse=True)
-                tensor_batch = tensorify(self.vocabs, minibatch, self.device)
+                tensor_batch = tensorify(
+                    self.vocabs, minibatch, self.device, self.left_pad
+                )
                 yield (tensor_batch, bucket_idx)
 
 
diff --git a/onmt/inputters/text_utils.py b/onmt/inputters/text_utils.py
index 4a145842e3..c6e9e8ff1b 100644
--- a/onmt/inputters/text_utils.py
+++ b/onmt/inputters/text_utils.py
@@ -168,7 +168,7 @@ def parse_align_idx(align_pharaoh):
     return flatten_align_idx
 
 
-def tensorify(vocabs, minibatch, device):
+def tensorify(vocabs, minibatch, device, left_pad=False):
     """
     This function transforms a batch of example in tensors
     Each example looks like
@@ -193,21 +193,37 @@ def tensorify(vocabs, minibatch, device):
     }
     """
     tensor_batch = {}
-    tbatchsrc = [
-        torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device)
-        for ex, indice in minibatch
-    ]
+    if left_pad:
+        tbatchsrc = [
+            torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device).flip(
+                dims=[0]
+            )
+            for ex, indice in minibatch
+        ]
+    else:
+        tbatchsrc = [
+            torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device)
+            for ex, indice in minibatch
+        ]
     padidx = vocabs["src"][DefaultTokens.PAD]
     tbatchsrc = pad_sequence(tbatchsrc, batch_first=True, padding_value=padidx)
     if "feats" in minibatch[0][0]["src"]:
         tbatchfs = [tbatchsrc]
         for feat_id in range(len(minibatch[0][0]["src"]["feats"])):
-            tbatchfeat = [
-                torch.tensor(
-                    ex["src"]["feats"][feat_id], dtype=torch.long, device=device
-                )
-                for ex, indice in minibatch
-            ]
+            if left_pad:
+                tbatchfeat = [
+                    torch.tensor(
+                        ex["src"]["feats"][feat_id], dtype=torch.long, device=device
+                    ).flip(dims=[0])
+                    for ex, indice in minibatch
+                ]
+            else:
+                tbatchfeat = [
+                    torch.tensor(
+                        ex["src"]["feats"][feat_id], dtype=torch.long, device=device
+                    )
+                    for ex, indice in minibatch
+                ]
             padidx = vocabs["src_feats"][feat_id][DefaultTokens.PAD]
             tbatchfeat = pad_sequence(
                 tbatchfeat, batch_first=True, padding_value=padidx
@@ -218,7 +234,10 @@ def tensorify(vocabs, minibatch, device):
         # Need to add features in last dimensions
         tbatchsrc = tbatchsrc[:, :, None]
 
-    tensor_batch["src"] = tbatchsrc
+    if left_pad:
+        tensor_batch["src"] = tbatchsrc.flip(dims=[1])
+    else:
+        tensor_batch["src"] = tbatchsrc
 
     tensor_batch["srclen"] = torch.tensor(
         [len(ex["src"]["src_ids"]) for ex, indice in minibatch],
diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py
index 46e1780d36..43eb7e8536 100644
--- a/onmt/modules/multi_headed_attn.py
+++ b/onmt/modules/multi_headed_attn.py
@@ -405,6 +405,7 @@ def forward(
         # 1) Project key, value, and query.
         # as a reminder at training layer_cache[0] remains False
         if self.layer_cache[0]:
+            # Retrieve keys and values from the KV cache (decoding mode only).
             if self.attn_type == "self":
                 query, key, value = (
                     self.linear_query(query),
@@ -451,6 +452,7 @@ def forward(
                 self.layer_cache[1]["keys"] = key
                 self.layer_cache[1]["values"] = value
         else:
+            # Retrieve keys and values from linear layers (training mode).
             key = self.maybe_ckpt(self.linear_keys, key)
             value = self.maybe_ckpt(self.linear_values, value)
             query = self.maybe_ckpt(self.linear_query, query)
@@ -491,12 +493,12 @@ def forward(
             self.flash2
             and l > 256  # https://github.com/Dao-AILab/flash-attention/issues/591
         )
-
         if (
             self.max_relative_positions in [-1, 0]
             and not return_attn
            and query.device != torch.device("cpu")
         ):
+            # Apply flash2 attention.
             causal = self.is_decoder and self.attn_type == "self" and mask is not None
             if self.is_decoder and self.attn_type == "self" and flash2:
                 if causal:
@@ -514,6 +516,7 @@ def forward(
                     window_size=window_size,
                 ).transpose(1, 2)
             else:
+                # Apply scaled dot product attention.
                 with torch.backends.cuda.sdp_kernel(
                     enable_flash=False, enable_math=True, enable_mem_efficient=True
                 ):
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 10f89665b5..856f786b3c 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -658,7 +658,6 @@ def _decode_and_generate(
             step=step,
             return_attn=self.global_scorer.has_cov_pen or return_attn,
         )
-
         # Generator forward.
         if not self.copy_attn:
             if "std" in dec_attn:
@@ -988,16 +987,6 @@ def _align_forward(self, batch, predictions):
 
     def translate_batch(self, batch, attn_debug):
         """Translate a batch of sentences."""
-        batch_size = len(batch["srclen"])
-        if batch_size != 1:
-            warning_msg = (
-                "GeneratorLM does not support batch_size != 1"
-                " nicely. You can remove this limitation here."
-                " With batch_size > 1 the end of each input is"
-                " repeated until the input is finished. Then"
-                " generation will start."
-            )
-            self._log(warning_msg)
         with torch.no_grad():
             if self.sample_from_topk != 0 or self.sample_from_topp != 0:
                 decode_strategy = GreedySearchLM(
@@ -1061,7 +1050,7 @@ def tile_to_beam_size_after_initial_step(self, fn_map_state, log_probs):
             log_probs = log_probs[:, -1, :]
         return log_probs
 
-    def _translate_batch_with_strategy(self, batch, decode_strategy):
+    def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):
         """Translate a batch of sentences step by step using cache.
 
         Args:
@@ -1081,7 +1070,12 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
         src = batch["src"]
         src_len = batch["srclen"]
 
-        src, src_len, target_prefix = self.split_src_to_prevent_padding(src, src_len)
+        if left_pad:
+            target_prefix = None
+        else:
+            src, src_len, target_prefix = self.split_src_to_prevent_padding(
+                src, src_len
+            )
 
         # (2) init decoder
         self.model.decoder.init_state(src, None, None)
@@ -1109,7 +1103,6 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
             decoder_input = (
                 src if step == 0 else decode_strategy.current_predictions.view(-1, 1, 1)
            )
-
             log_probs, attn = self._decode_and_generate(
                 decoder_input,
                 None,
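
Appendix sketch (illustrative only, not part of the patch): the patch_mask added to _compute_dec_mask un-masks any row of the decoder mask that is masked everywhere, which is what happens at left-padded positions; without it, scaled dot product attention would softmax over a row of all -inf scores and return NaN. The toy values below (a 3-token sequence whose first two positions are padding) are assumptions made up for the demonstration, not taken from the PR.

# Toy reproduction of the dec_mask patch with assumed shapes/values.
import torch

T = 3
# One sequence whose first two positions are (left) padding: (B, 1, T).
tgt_pad_mask = torch.tensor([[[True, True, False]]])
# Standard causal mask, True above the diagonal means "masked": (1, T, T).
allowed = torch.ones(T, T, dtype=torch.uint8).tril_(0)
future_mask = ~allowed.bool().view(1, T, T)

# Same arithmetic as the patched _compute_dec_mask.
patch_mask = ~torch.all(
    tgt_pad_mask + future_mask, dim=2, keepdim=True
).expand_as(tgt_pad_mask + future_mask)
dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
dec_mask = torch.logical_and(dec_mask, patch_mask)
print(dec_mask)
# Rows 0 and 1 (the pad positions) would be all True without the patch;
# with it they become all False, so attention over them stays finite.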
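
Appendix sketch (illustrative only, not part of the patch): the left_pad path in tensorify builds on pad_sequence, which can only pad on the right, so each sequence is flipped, padded, and the padded batch is flipped back along the time axis to obtain left padding. The token ids and pad index below are made-up values for the demonstration.

# Minimal sketch of the flip / pad / flip trick used by the left_pad path.
import torch
from torch.nn.utils.rnn import pad_sequence

pad_idx = 1  # hypothetical pad id
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

# Default behaviour (left_pad=False): padding ends up on the right.
right_padded = pad_sequence(seqs, batch_first=True, padding_value=pad_idx)
# -> [[5, 6, 7], [8, 9, 1]]

# left_pad=True path: flip each sequence, pad, then flip along the time axis.
left_padded = pad_sequence(
    [s.flip(dims=[0]) for s in seqs], batch_first=True, padding_value=pad_idx
).flip(dims=[1])
# -> [[5, 6, 7], [1, 8, 9]]
print(right_padded, left_padded, sep="\n")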