diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index b92d3c380d..1a09567e2f 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -464,11 +464,11 @@ def forward( or query.dtype != torch.float16 ): if self.max_relative_positions == -1: # Rotary Embeddings - if seqlen > self.rope.size(0): - + if seqlen + start_pos > self.rope.size(0): + # Resize rotary embeddings. self.rope, _, _ = rotaryembeddings( self.rotary_dim, - maxseqlen=(seqlen + 2048), + maxseqlen=(seqlen + start_pos + 2048), base=self.rotary_theta, device=self.rope.device, )