
Commit 9250a5a

some code cleaning
l-k-11235 committed Feb 22, 2024
1 parent 3f06486 commit 9250a5a
Showing 1 changed file with 18 additions and 16 deletions.
onmt/modules/multi_headed_attn.py (34 changes: 18 additions & 16 deletions)
@@ -18,13 +18,15 @@
 # are both < 2048 tokens.
 
 
-def rotaryembeddings(dim: int, maxseqlen=2048, base=10000):
+def rotaryembeddings(dim: int, maxseqlen=2048, base=10000, device=None):
     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
     tmax = torch.arange(maxseqlen, device=inv_freq.device)
     rope = torch.outer(tmax, inv_freq).float()
     # rope is now matrix [maxseqlen, dim/2]
     rope = torch.polar(torch.ones_like(rope), rope)
     rope = torch.cat((rope, rope), dim=1)
+    if device is not None:
+        rope = rope.to(device)
     cos = rope[:, : rope.size(1) // 2].real.contiguous().half()
     sin = rope[:, : rope.size(1) // 2].imag.contiguous().half()
     return rope, cos, sin
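
For context, a minimal usage sketch of the updated helper. The call site, the head dimension of 64, and the device selection below are illustrative only and not part of the commit; building the rope/cos/sin cache directly on the target device is what allows the later `.to(...)` copies to be dropped in the hunks further down.

import torch
from onmt.modules.multi_headed_attn import rotaryembeddings

# Hypothetical call site: place the rotary cache on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rope, cos, sin = rotaryembeddings(64, maxseqlen=4096, base=10000, device=device)
# rope: complex tensor of shape [4096, 64] on `device`;
# cos / sin: float16 tensors of shape [4096, 32], also on `device`.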
@@ -468,8 +470,8 @@ def forward(
                         self.rotary_dim,
                         maxseqlen=(seqlen + 2048),
                         base=self.rotary_theta,
+                        device=self.rope.device,
                     )
-                    self.rope = self.rope.to(self.rope.device)
                 rope = self.rope[start_pos : start_pos + seqlen]
                 query, key = apply_rotary_emb(
                     query, key, rope, interleave=self.rotary_interleave
@@ -511,18 +513,19 @@ def forward(
                 ],
                 dim=-2,
             )
-            if self.max_relative_positions == -1:  # Rotary
+            if (
+                self.max_relative_positions == -1
+                and start_pos + 32 >= self.rope.size(0)
+            ):
                 # Resize rotary embeddings.
-                if start_pos + 32 >= self.rope.size(0):
-                    # We take a margin of 32 tokens as the kv_cache
-                    # is incremented by 32 tokens every 32 tokens
-
-                    self.rope, self.cos, self.sin = rotaryembeddings(
-                        self.rotary_dim,
-                        maxseqlen=(start_pos + 2048),
-                        base=self.rotary_theta,
-                    )
-                    self.rope = self.rope.to(self.rope.device)
+                # We take a margin of 32 tokens as the kv_cache
+                # is incremented by 32 tokens every 32 tokens.
+                self.rope, self.cos, self.sin = rotaryembeddings(
+                    self.rotary_dim,
+                    maxseqlen=(start_pos + 2048),
+                    base=self.rotary_theta,
+                    device=self.rope.device,
+                )
 
             if sliding_window > 0 and key.size(2) > sliding_window:
                 self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
@@ -531,8 +534,6 @@ def forward(
                 self.layer_cache[1]["values"] = self.layer_cache[1]["values"][
                     :, :, 1:, :
                 ]
-            self.cos = self.cos.to(query.device)
-            self.sin = self.sin.to(query.device)
             context = self.flash_attn_with_kvcache(
                 query.transpose(1, 2),
                 self.layer_cache[1]["keys"].transpose(1, 2),
@@ -596,8 +597,9 @@ def forward(
                     self.rotary_dim,
                     maxseqlen=(seqlen + 2048),
                     base=self.rotary_theta,
+                    device=query.device,
                 )
-            rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
+            rope = self.rope[start_pos : start_pos + seqlen]
             query, key = apply_rotary_emb(
                 query, key, rope, interleave=self.rotary_interleave
             )
