
Commit

revert attn change
epwalsh committed Oct 30, 2024
1 parent e463768 · commit 69326b0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/olmo_core/nn/attention.py
@@ -246,8 +246,8 @@ def forward(
         # (batch_size, n_kv_heads, seq_len, head_dim)
         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

-        # PyTorch's SDPA doesn't support MQA/GQA, so we have to do this.
-        if self.n_heads != self.n_kv_heads:
+        # PyTorch's SDPA doesn't support GQA, so we have to do this.
+        if self.n_heads != self.n_kv_heads and self.n_kv_heads > 1:
             k = k.repeat_interleave(
                 self.n_heads // self.n_kv_heads, dim=1, output_size=self.n_heads
             )
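
For readers of the diff, here is a minimal standalone sketch of the expansion this code performs before calling PyTorch's scaled_dot_product_attention. It is not part of the commit; the shapes and head counts are illustrative.

import torch
import torch.nn.functional as F

batch_size, seq_len, head_dim = 2, 16, 64
n_heads, n_kv_heads = 8, 2  # GQA: every 4 query heads share one KV head

q = torch.randn(batch_size, n_heads, seq_len, head_dim)
k = torch.randn(batch_size, n_kv_heads, seq_len, head_dim)
v = torch.randn(batch_size, n_kv_heads, seq_len, head_dim)

# SDPA expects q, k, and v to have matching head counts, so each KV head
# is repeated n_heads // n_kv_heads times along the head dimension.
k = k.repeat_interleave(n_heads // n_kv_heads, dim=1, output_size=n_heads)
v = v.repeat_interleave(n_heads // n_kv_heads, dim=1, output_size=n_heads)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])

The restored guard skips the expansion when n_kv_heads == 1 (MQA): a single KV head has a size-1 head dimension that can broadcast inside SDPA, which is presumably why only the GQA case (1 < n_kv_heads < n_heads) needs the explicit repeat. That reading is inferred from the diff, not stated in the commit.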
