diff --git a/src/olmo_core/nn/attention.py b/src/olmo_core/nn/attention.py
index 096c4302..d7284343 100644
--- a/src/olmo_core/nn/attention.py
+++ b/src/olmo_core/nn/attention.py
@@ -246,8 +246,8 @@ def forward(
         # (batch_size, n_kv_heads, seq_len, head_dim)
         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
 
-        # PyTorch's SDPA doesn't support MQA/GQA, so we have to do this.
-        if self.n_heads != self.n_kv_heads:
+        # PyTorch's SDPA doesn't support GQA, so we have to do this.
+        if self.n_heads != self.n_kv_heads and self.n_kv_heads > 1:
             k = k.repeat_interleave(
                 self.n_heads // self.n_kv_heads, dim=1, output_size=self.n_heads
             )
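
The sketch below (not the repo's code; shapes and sizes are illustrative) shows the rationale implied by the new guard: when n_kv_heads == 1 (MQA), the singleton KV-head dimension broadcasts against the query heads inside torch.nn.functional.scaled_dot_product_attention, so only the grouped case (1 < n_kv_heads < n_heads, i.e. GQA) needs the heads expanded with repeat_interleave first.

import torch
import torch.nn.functional as F

# Illustrative sizes; shapes follow the diff's (batch, heads, seq, head_dim) layout.
batch_size, seq_len, head_dim = 2, 16, 64
n_heads, n_kv_heads = 8, 2  # GQA: 1 < n_kv_heads < n_heads

q = torch.randn(batch_size, n_heads, seq_len, head_dim)
k = torch.randn(batch_size, n_kv_heads, seq_len, head_dim)
v = torch.randn(batch_size, n_kv_heads, seq_len, head_dim)

if n_heads != n_kv_heads and n_kv_heads > 1:
    # GQA: repeat each KV head across its group of query heads,
    # since SDPA won't do this grouping itself.
    k = k.repeat_interleave(n_heads // n_kv_heads, dim=1, output_size=n_heads)
    v = v.repeat_interleave(n_heads // n_kv_heads, dim=1, output_size=n_heads)
# MQA (n_kv_heads == 1) is left as-is: a size-1 head dimension
# broadcasts against q's n_heads inside SDPA, so no repeat is needed.

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 16, 64])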