From 9250a5aece32639d9ce2c06e175ad875670ac36b Mon Sep 17 00:00:00 2001
From: l-k-11235
Date: Thu, 22 Feb 2024 09:56:45 +0100
Subject: [PATCH] some code cleaning

---
 onmt/modules/multi_headed_attn.py | 34 ++++++++++++++++++----------------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py
index ac9340b666..b92d3c380d 100644
--- a/onmt/modules/multi_headed_attn.py
+++ b/onmt/modules/multi_headed_attn.py
@@ -18,13 +18,15 @@
 # are both < 2048 tokens.
 
 
-def rotaryembeddings(dim: int, maxseqlen=2048, base=10000):
+def rotaryembeddings(dim: int, maxseqlen=2048, base=10000, device=None):
     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
     tmax = torch.arange(maxseqlen, device=inv_freq.device)
     rope = torch.outer(tmax, inv_freq).float()
     # rope is now matrix [maxseqlen, dim/2]
     rope = torch.polar(torch.ones_like(rope), rope)
     rope = torch.cat((rope, rope), dim=1)
+    if device is not None:
+        rope = rope.to(device)
     cos = rope[:, : rope.size(1) // 2].real.contiguous().half()
     sin = rope[:, : rope.size(1) // 2].imag.contiguous().half()
     return rope, cos, sin
@@ -468,8 +470,8 @@ def forward(
                         self.rotary_dim,
                         maxseqlen=(seqlen + 2048),
                         base=self.rotary_theta,
+                        device=self.rope.device,
                     )
-                    self.rope = self.rope.to(self.rope.device)
                 rope = self.rope[start_pos : start_pos + seqlen]
                 query, key = apply_rotary_emb(
                     query, key, rope, interleave=self.rotary_interleave
@@ -511,18 +513,19 @@ def forward(
                         ],
                         dim=-2,
                     )
-                    if self.max_relative_positions == -1:  # Rotary
+                    if (
+                        self.max_relative_positions == -1
+                        and start_pos + 32 >= self.rope.size(0)
+                    ):
                         # Resize rotary embeddings.
-                        if start_pos + 32 >= self.rope.size(0):
-                            # We take a margin of 32 tokens as the kv_cache
-                            # is incremented by 32 tokens every 32 tokens
-
-                            self.rope, self.cos, self.sin = rotaryembeddings(
-                                self.rotary_dim,
-                                maxseqlen=(start_pos + 2048),
-                                base=self.rotary_theta,
-                            )
-                            self.rope = self.rope.to(self.rope.device)
+                        # We take a margin of 32 tokens as the kv_cache
+                        # is incremented by 32 tokens every 32 tokens.
+                        self.rope, self.cos, self.sin = rotaryembeddings(
+                            self.rotary_dim,
+                            maxseqlen=(start_pos + 2048),
+                            base=self.rotary_theta,
+                            device=self.rope.device,
+                        )
                 if sliding_window > 0 and key.size(2) > sliding_window:
                     self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
                         :, :, 1:, :
                     ]
                     self.layer_cache[1]["values"] = self.layer_cache[1]["values"][
                         :, :, 1:, :
                     ]
-                self.cos = self.cos.to(query.device)
-                self.sin = self.sin.to(query.device)
                 context = self.flash_attn_with_kvcache(
                     query.transpose(1, 2),
                     self.layer_cache[1]["keys"].transpose(1, 2),
@@ -596,8 +597,9 @@ def forward(
                     self.rotary_dim,
                     maxseqlen=(seqlen + 2048),
                     base=self.rotary_theta,
+                    device=query.device,
                 )
-            rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
+            rope = self.rope[start_pos : start_pos + seqlen]
             query, key = apply_rotary_emb(
                 query, key, rope, interleave=self.rotary_interleave
             )
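
Usage sketch (outside the patch): the snippet below restates the patched rotaryembeddings
helper from the first hunk and shows how a caller might build the rotary table directly on
a target device via the new device argument, instead of moving it afterwards with .to(),
which is what lets the decoding paths above drop their extra .to() calls. The dim and
maxseqlen values and the __main__ harness are illustrative assumptions, not part of the
OpenNMT-py code.

    import torch


    def rotaryembeddings(dim: int, maxseqlen=2048, base=10000, device=None):
        # Inverse frequencies for each pair of channels.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        tmax = torch.arange(maxseqlen, device=inv_freq.device)
        rope = torch.outer(tmax, inv_freq).float()  # [maxseqlen, dim/2]
        rope = torch.polar(torch.ones_like(rope), rope)
        rope = torch.cat((rope, rope), dim=1)  # [maxseqlen, dim], complex64
        if device is not None:
            # Move the table once at build time; callers no longer need .to() afterwards.
            rope = rope.to(device)
        cos = rope[:, : rope.size(1) // 2].real.contiguous().half()
        sin = rope[:, : rope.size(1) // 2].imag.contiguous().half()
        return rope, cos, sin


    if __name__ == "__main__":
        # Illustrative sizes; rotary_dim and the +2048 margin play this role in the patch.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        rope, cos, sin = rotaryembeddings(64, maxseqlen=4096, device=device)
        print(rope.shape, rope.device, cos.dtype)  # torch.Size([4096, 64]) <device> torch.float16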