RoPE implementation is incorrect #289
PS: the torch version is consistent with the one in transformers.
----------------------------------------------------------------------------------------------------
M=8192, N=1024
----------------------------------------------------------------------------------------------------
out_f32: ['-1.08793437 ', '-1.92047811 ', '-2.2731328 '], time:0.018716ms
tensor([[-1.0879, -1.9205, -2.2731, ..., 0.6204, 0.9124, -1.4198],
[-1.0231, -0.3712, -1.1853, ..., -1.8487, 0.1226, -1.0530],
[-0.5046, -1.3417, 1.2834, ..., 1.0105, 0.0417, 1.0380],
...,
[-0.0632, -0.1381, 0.5857, ..., -0.1382, -0.8494, -1.7711],
[ 0.5749, 0.5440, -0.3250, ..., 2.1607, -1.0835, -1.0153],
[-0.3775, 0.5131, 0.6802, ..., 1.6858, -0.7181, -0.4846]],
device='cuda:0')
out_f32x4_pack: ['-1.08793437 ', '-1.92047811 ', '-2.2731328 '], time:0.018263ms
tensor([[-1.0879, -1.9205, -2.2731, ..., 0.6204, 0.9124, -1.4198],
[-1.0231, -0.3712, -1.1853, ..., -1.8487, 0.1226, -1.0530],
[-0.5046, -1.3417, 1.2834, ..., 1.0105, 0.0417, 1.0380],
...,
[-0.0632, -0.1381, 0.5857, ..., -0.1382, -0.8494, -1.7711],
[ 0.5749, 0.5440, -0.3250, ..., 2.1607, -1.0835, -1.0153],
[-0.3775, 0.5131, 0.6802, ..., 1.6858, -0.7181, -0.4846]],
device='cuda:0')
out_f32_th: ['-1.08793437 ', '-1.92047811 ', '-2.2731328 '], time:1.037562ms
tensor([[-1.0879, -1.9205, -2.2731, ..., 0.6204, 0.9124, -1.4198],
[-1.0231, -0.3712, -1.1902, ..., -0.9135, -0.8198, -0.6722],
[-0.5046, -1.3417, 1.3063, ..., -0.3342, 0.9266, -0.4697],
...,
[-0.0631, -0.1382, 0.6218, ..., -0.2877, -1.9617, 0.1009],
[ 0.5746, 0.5443, -0.3522, ..., -1.3872, -0.2200, 1.4685],
[-0.3778, 0.5129, 1.5211, ..., -1.6983, 0.7339, 0.4603]],
device='cuda:0')
llama-rotary:
tensor([[[[-1.0879, -1.9205, -2.2731, ..., 0.6204, 0.9124, -1.4198],
[ 0.6609, 0.8694, -0.8100, ..., -0.9133, -0.8200, -0.6723],
[ 0.3740, 0.5371, 0.0529, ..., -0.3344, 0.9264, -0.4701],
...,
[-0.1139, 1.0056, -1.3266, ..., 1.0161, -1.3343, 1.5228],
[ 0.5761, 2.1631, 0.2380, ..., -0.6242, -0.0753, 0.9082],
[ 0.2910, 1.1974, 1.6174, ..., -0.9288, -0.8432, -0.7583]]]],
device='cuda:0')

import torch
from torch import nn

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, base, dim, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query tensor.

    Args:
        q (`torch.Tensor`): The query tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which to unsqueeze `cos` and `sin` so that they broadcast correctly
            to the dimensions of `q`. For example, if `cos` and `sin` have the shape
            [batch_size, seq_len, head_dim] and `q` has the shape [batch_size, heads, seq_len, head_dim],
            then `unsqueeze_dim=1` makes them broadcastable; if `q` has the shape
            [batch_size, seq_len, heads, head_dim], set `unsqueeze_dim=2`.

    Returns:
        `torch.Tensor`: the query tensor rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    return q_embed

...
# `x`, `M` and `N` are defined in the elided setup above (M=8192, N=1024 in this run).
rotary = LlamaRotaryEmbedding(10000.0, N)
position_ids = torch.arange(0, M, dtype=torch.int32, device=x.device).reshape((1, -1))
cos, sin = rotary(x, position_ids)
pos_emb = apply_rotary_pos_emb(x, cos, sin)
print(pos_emb)
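To quantify the mismatch directly, a check along these lines can be run after the snippet above. This is only a sketch: `out_cuda` is a placeholder for the (M, N) tensor returned by the CUDA RoPE kernel being benchmarked (not the repo's actual binding name), and `x` is assumed to have shape (1, 1, M, N) as the printed reference suggests.

# Sketch only: compare the CUDA kernel output against the transformers reference.
# `out_cuda` is a placeholder name; replace it with the repo's actual kernel call.
ref = pos_emb.reshape(M, N)                                     # drop the leading broadcast dims
print(torch.allclose(out_cuda, ref, atol=1e-3))                 # False: rows beyond posid = 0 diverge
print(torch.allclose(out_cuda[:, :2], ref[:, :2], atol=1e-3))   # True: the first value pair agrees in every row
print((out_cuda - ref).abs().max())                             # largest element-wise deviation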
As the title says, I found that the CUDA version of RoPE and the torch version give different results; the torch version is the correct one. The output shown above was produced with the show_all option enabled. Notice that, except for the first row, only the first two values of each embedding match. The cpp version computes sin = posid / (1.0 / theta) = posid * theta, while the python version computes sin = posid * (1.0 / theta) = posid / theta. When posid = 0, cos = 1 and sin = 0, so that row is unaffected.
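A minimal sketch of the discrepancy, using the same inv_freq definition as the transformers code above (theta_i = base^(2i/dim)): the reference angle is posid * inv_freq = posid / theta_i, while the kernel as described computes posid * theta_i. The two agree only when posid = 0 or at the i = 0 frequency (theta_0 = 1), which matches the observation that only the first row and the first value pair are identical.

import torch

# Sketch: contrast the two angle conventions described above.
base, dim = 10000.0, 1024
theta = base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # theta_i = base^(2i/dim)
inv_freq = 1.0 / theta                                                # used by the python/transformers path

posid = 1
angle_python = posid * inv_freq   # posid / theta: the reference convention
angle_cpp = posid * theta         # posid * theta: what the issue says the CUDA kernel computes

print(angle_python[:3])  # ~[1.0000, 0.9822, 0.9647]; agrees with angle_cpp only at index 0
print(angle_cpp[:3])     # ~[1.0000, 1.0181, 1.0366]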