From d9d8b77d0027d0d2f9e6d8ac5ff62341eb8a6f98 Mon Sep 17 00:00:00 2001
From: Lina Khodja <57141057+l-k-11235@users.noreply.github.com>
Date: Thu, 22 Feb 2024 12:23:14 +0100
Subject: [PATCH] fix generation with large sequences when flash2 is False
 (#2564)

---
 onmt/modules/multi_headed_attn.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py
index b92d3c380d..1a09567e2f 100644
--- a/onmt/modules/multi_headed_attn.py
+++ b/onmt/modules/multi_headed_attn.py
@@ -464,11 +464,11 @@ def forward(
                     or query.dtype != torch.float16
                 ):
                     if self.max_relative_positions == -1:  # Rotary Embeddings
-                        if seqlen > self.rope.size(0):
-
+                        if seqlen + start_pos > self.rope.size(0):
+                            # Resize rotary embeddings.
                             self.rope, _, _ = rotaryembeddings(
                                 self.rotary_dim,
-                                maxseqlen=(seqlen + 2048),
+                                maxseqlen=(seqlen + start_pos + 2048),
                                 base=self.rotary_theta,
                                 device=self.rope.device,
                             )
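
Note on the change: with a KV cache, the current decoding chunk covers absolute positions [start_pos, start_pos + seqlen), so the rotary table must hold at least seqlen + start_pos rows; comparing only seqlen against self.rope.size(0) lets long generations index past the end of the table when the flash2 path is not taken. The snippet below is a minimal, self-contained sketch of that check-and-grow pattern; build_rope_cache, ensure_rope_capacity, and the headroom argument are illustrative names and not part of OpenNMT-py, which uses its own rotaryembeddings helper.

    import torch

    def build_rope_cache(dim, maxseqlen, base=10000.0, device=None):
        """Precompute a complex rotary table for positions [0, maxseqlen)."""
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        positions = torch.arange(maxseqlen, device=device).float()
        angles = torch.outer(positions, inv_freq)            # (maxseqlen, dim // 2)
        return torch.polar(torch.ones_like(angles), angles)  # complex64 table

    def ensure_rope_capacity(rope, dim, seqlen, start_pos, base=10000.0, headroom=2048):
        """Grow the rotary table when the highest absolute position exceeds its size.

        The current chunk spans absolute positions [start_pos, start_pos + seqlen),
        so the table needs at least seqlen + start_pos rows; checking seqlen alone
        is only correct when there is no cache offset.
        """
        if seqlen + start_pos > rope.size(0):
            rope = build_rope_cache(dim, seqlen + start_pos + headroom, base, rope.device)
        return rope

    # Example: 16 tokens already in the KV cache, decoding one new token.
    rope = build_rope_cache(dim=64, maxseqlen=8)
    rope = ensure_rope_capacity(rope, dim=64, seqlen=1, start_pos=16)
    print(rope.size(0))  # >= 17, large enough to index absolute position 16

The extra headroom (2048 in the patch) simply avoids reallocating the table on every decoding step once the precomputed range is exceeded.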