diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 4da2b81d..f3639155 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -394,14 +394,24 @@ def __init__( tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. - self.rotary_embedding = RotaryEmbedding( - dim=self.d_qk, - end=config.max_position_embeddings, - theta=config.rope_theta, - ) + if config.rope_interleaved: + self.rotary_embedding = RotaryEmbedding( + dim=self.d_qk, + end=config.max_position_embeddings, + theta=config.rope_theta, + ) + else: + self.rotary_embedding = LlamaRotaryEmbedding( + dim=self.d_qk, + end=config.max_position_embeddings, + theta=config.rope_theta, + ) + self.rope_interleaved = config.rope_interleaved # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) - self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=True) + self.flash_rotary_embedding = FlashRotaryEmbedding( + dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved + ) self.o_proj = TensorParallelRowLinear( config.num_attention_heads * self.d_qk, @@ -480,8 +490,14 @@ def forward( # Compute rotary embeddings # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache old_rotary_embed_end = self.rotary_embedding.end - query_states = self.rotary_embedding(query_states, position_ids=position_ids) - key_states = self.rotary_embedding(key_states, position_ids=position_ids) + # interleaved version. + if self.rope_interleaved: + query_states = self.rotary_embedding(query_states, position_ids=position_ids) + key_states = self.rotary_embedding(key_states, position_ids=position_ids) + # non interleaved version. + else: + cos, sin = self.rotary_embedding(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if "key" not in store: # First inference iteration (Prefill) @@ -620,7 +636,7 @@ def forward( cache_seqlens=position_offsets.contiguous(), softmax_scale=softmax_scale, causal=True, - rotary_interleaved=False, # GPT-NeoX style + rotary_interleaved=False, # the value is not used unless rotary_cos/sin is provided. https://github.com/Dao-AILab/flash-attention ) store.update(