Incorrect RoPE implementation #289


Closed
hebangwen opened this issue Apr 15, 2025 · 1 comment

As the title says, I found that the CUDA version of RoPE differs from the torch version; the torch version should be the correct result.

With the show_all option enabled, the results are shown below. As you can see, except for the first row, only the first two values of each embedding match.

----------------------------------------------------------------------------------------------------
                                        M=4096, N=512
----------------------------------------------------------------------------------------------------
out_f32: ['1.87628162  ', '0.54061633  ', '-0.25348133 '], time:0.006235ms
tensor([[ 1.8763,  0.5406, -0.2535,  ...,  0.0888, -1.4432, -0.6644],
        [ 0.1707,  0.0664, -0.6512,  ...,  1.1742,  1.0595,  0.9371],
        [-0.7242,  0.9565, -2.6008,  ...,  1.2142,  0.0903, -1.2440],
        ...,
        [ 0.2749, -0.8864,  0.4484,  ...,  1.4880, -1.5875, -0.4284],
        [-1.1559, -0.1207, -0.7012,  ..., -2.5590,  2.8657,  1.8899],
        [ 0.2627, -1.0124,  2.0098,  ...,  1.0261, -0.7263,  0.7020]],
       device='cuda:0')
out_f32x4_pack: ['1.87628162  ', '0.54061633  ', '-0.25348133 '], time:0.005949ms
tensor([[ 1.8763,  0.5406, -0.2535,  ...,  0.0888, -1.4432, -0.6644],
        [ 0.1707,  0.0664, -0.6512,  ...,  1.1742,  1.0595,  0.9371],
        [-0.7242,  0.9565, -2.6008,  ...,  1.2142,  0.0903, -1.2440],
        ...,
        [ 0.2749, -0.8864,  0.4484,  ...,  1.4880, -1.5875, -0.4284],
        [-1.1559, -0.1207, -0.7012,  ..., -2.5590,  2.8657,  1.8899],
        [ 0.2627, -1.0124,  2.0098,  ...,  1.0261, -0.7263,  0.7020]],
       device='cuda:0')
out_f32_th: ['1.87628162  ', '0.54061633  ', '-0.25348133 '], time:0.389457ms
tensor([[ 1.8763,  0.5406, -0.2535,  ...,  0.0888, -1.4432, -0.6644],
        [ 0.1707,  0.0664, -0.6292,  ...,  0.4374,  1.3610, -0.3851],
        [-0.7242,  0.9565, -2.6385,  ..., -1.7485, -1.1688,  0.4353],
        ...,
        [ 0.2751, -0.8863,  0.5583,  ..., -1.1978,  0.6213,  1.5224],
        [-1.1558, -0.1212, -0.5207,  ...,  2.6085, -3.0074, -1.6552],
        [ 0.2630, -1.0123,  1.8745,  ..., -1.7671, -0.2767, -0.9714]],
       device='cuda:0')
----------------------------------------------------------------------------------------------------
                                        M=4096, N=1024
----------------------------------------------------------------------------------------------------
out_f32: ['1.27092159  ', '-0.83720589 ', '-0.53293085 '], time:0.010192ms
tensor([[ 1.2709, -0.8372, -0.5329,  ...,  0.0735,  0.0132, -1.6987],
        [ 0.9150, -1.0800,  0.1529,  ..., -0.2318, -0.6195, -0.2894],
        [ 0.4141,  1.5823, -1.1214,  ..., -0.9526,  0.1026,  0.2473],
        ...,
        [-0.7135, -0.6693, -1.0882,  ...,  0.4784,  1.8427,  0.6524],
        [ 0.0140,  0.0685,  0.2967,  ...,  0.6755,  2.2026, -2.7582],
        [ 0.3151, -0.7361,  1.1072,  ...,  1.2773,  0.4074,  2.4871]],
       device='cuda:0')
out_f32x4_pack: ['1.27092159  ', '-0.83720589 ', '-0.53293085 '], time:0.009954ms
tensor([[ 1.2709, -0.8372, -0.5329,  ...,  0.0735,  0.0132, -1.6987],
        [ 0.9150, -1.0800,  0.1529,  ..., -0.2318, -0.6195, -0.2894],
        [ 0.4141,  1.5823, -1.1214,  ..., -0.9526,  0.1026,  0.2473],
        ...,
        [-0.7135, -0.6693, -1.0882,  ...,  0.4784,  1.8427,  0.6524],
        [ 0.0140,  0.0685,  0.2967,  ...,  0.6755,  2.2026, -2.7582],
        [ 0.3151, -0.7361,  1.1072,  ...,  1.2773,  0.4074,  2.4871]],
       device='cuda:0')
out_f32_th: ['1.27092159  ', '-0.83720589 ', '-0.53293085 '], time:0.605023ms
tensor([[ 1.2709, -0.8372, -0.5329,  ...,  0.0735,  0.0132, -1.6987],
        [ 0.9150, -1.0800,  0.1493,  ..., -1.2779, -0.5782,  0.3649],
        [ 0.4141,  1.5823, -1.0902,  ...,  0.1315,  0.1822, -0.1962],
        ...,
        [-0.7134, -0.6695, -0.1875,  ...,  0.2179, -0.6118, -1.8566],
        [ 0.0140,  0.0685, -0.4804,  ..., -0.7615, -1.9533,  2.9400],
        [ 0.3153, -0.7360,  0.2607,  ...,  0.4753, -2.3916, -0.7950]],
       device='cuda:0')
----------------------------------------------------------------------------------------------------
                                        M=8192, N=512
----------------------------------------------------------------------------------------------------
out_f32: ['-1.02901423 ', '0.70658654  ', '0.23782901  '], time:0.010240ms
tensor([[-1.0290,  0.7066,  0.2378,  ...,  1.0706,  0.4719, -0.9742],
        [-3.4676,  0.4451, -0.1615,  ...,  0.3343, -0.7377, -0.2053],
        [-0.6065,  0.5367, -0.6062,  ..., -0.8156,  0.0576, -0.6682],
        ...,
        [ 0.6828,  1.4610,  0.2043,  ..., -1.0048, -1.3309, -1.7609],
        [-0.5199, -0.6188, -0.8065,  ...,  1.5828,  1.1549, -0.1436],
        [-0.1906, -1.3899, -2.2631,  ..., -0.2651, -0.2952,  0.4099]],
       device='cuda:0')
out_f32x4_pack: ['-1.02901423 ', '0.70658654  ', '0.23782901  '], time:0.009990ms
tensor([[-1.0290,  0.7066,  0.2378,  ...,  1.0706,  0.4719, -0.9742],
        [-3.4676,  0.4451, -0.1615,  ...,  0.3343, -0.7377, -0.2053],
        [-0.6065,  0.5367, -0.6062,  ..., -0.8156,  0.0576, -0.6682],
        ...,
        [ 0.6828,  1.4610,  0.2043,  ..., -1.0048, -1.3309, -1.7609],
        [-0.5199, -0.6188, -0.8065,  ...,  1.5828,  1.1549, -0.1436],
        [-0.1906, -1.3899, -2.2631,  ..., -0.2651, -0.2952,  0.4099]],
       device='cuda:0')
out_f32_th: ['-1.02901423 ', '0.70658654  ', '0.23782901  '], time:0.592649ms
tensor([[-1.0290,  0.7066,  0.2378,  ...,  1.0706,  0.4719, -0.9742],
        [-3.4676,  0.4451, -0.1410,  ...,  1.1787, -0.5714,  0.5098],
        [-0.6065,  0.5367, -0.7179,  ...,  0.7290, -0.6316,  0.2256],
        ...,
        [ 0.6817,  1.4616,  0.3496,  ..., -1.2603, -2.1459,  0.5166],
        [-0.5195, -0.6192, -1.0658,  ...,  0.9454, -0.7611, -0.8805],
        [-0.1900, -1.3900, -1.8335,  ...,  0.2775,  0.2876, -0.4153]],
       device='cuda:0')
----------------------------------------------------------------------------------------------------
                                        M=8192, N=1024
----------------------------------------------------------------------------------------------------
out_f32: ['-0.20427416 ', '-0.43109831 ', '-1.72741675 '], time:0.018680ms
tensor([[-0.2043, -0.4311, -1.7274,  ...,  0.5081, -1.0143,  1.3941],
        [-0.5052, -0.1664,  0.6081,  ..., -0.2041,  1.0397,  0.5240],
        [ 0.4711,  1.1454, -0.2527,  ..., -0.3829,  0.5055,  0.1488],
        ...,
        [ 1.6721,  2.1987,  0.5756,  ..., -0.3784,  0.6181,  0.6599],
        [ 0.5099, -1.6486,  0.5099,  ..., -0.6704,  0.5207, -1.8142],
        [ 0.5882,  2.8212,  1.3382,  ..., -0.6798, -0.7977,  0.7392]],
       device='cuda:0')
out_f32x4_pack: ['-0.20427416 ', '-0.43109831 ', '-1.72741675 '], time:0.018334ms
tensor([[-0.2043, -0.4311, -1.7274,  ...,  0.5081, -1.0143,  1.3941],
        [-0.5052, -0.1664,  0.6081,  ..., -0.2041,  1.0397,  0.5240],
        [ 0.4711,  1.1454, -0.2527,  ..., -0.3829,  0.5055,  0.1488],
        ...,
        [ 1.6721,  2.1987,  0.5756,  ..., -0.3784,  0.6181,  0.6599],
        [ 0.5099, -1.6486,  0.5099,  ..., -0.6704,  0.5207, -1.8142],
        [ 0.5882,  2.8212,  1.3382,  ..., -0.6798, -0.7977,  0.7392]],
       device='cuda:0')
out_f32_th: ['-0.20427416 ', '-0.43109831 ', '-1.72741675 '], time:1.144743ms
tensor([[-0.2043, -0.4311, -1.7274,  ...,  0.5081, -1.0143,  1.3941],
        [-0.5052, -0.1664,  0.6156,  ..., -1.7702,  1.0027, -0.5917],
        [ 0.4711,  1.1454, -0.3125,  ...,  1.6347, -0.0749, -0.5216],
        ...,
        [ 1.6704,  2.2000,  2.7192,  ...,  0.1448,  0.8472, -0.3159],
        [ 0.5110, -1.6483,  0.7006,  ...,  0.9939, -1.7888,  0.6023],
        [ 0.5869,  2.8215,  0.6271,  ...,  0.7010,  0.7725, -0.7655]],
       device='cuda:0')
----------------------------------------------------------------------------------------------------

The C++ version computes sin = posid / (1.0 / theta) = posid * theta, whereas the Python version computes sin = posid * (1.0 / theta) = posid / theta. When posid = 0, cos = 1 and sin = 0, so the first row is unaffected. (See the standalone check after the two snippets below.)

__global__ void rope_f32_kernel(float* x, float* out, int seq_len, int N){ 
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  float x1 = x[idx * 2];
  float x2 = x[idx * 2 + 1];
  int token_pos = idx / N; 
  int token_idx = idx % N;
  float exp_v = 1.0f / powf(theta, token_idx / (N * 2));
  // note how sin/cos are computed: token_pos is divided by exp_v instead of multiplied
  float sin_v = sinf(token_pos / exp_v);
  float cos_v = cosf(token_pos / exp_v);
  float out1 = x1 * cos_v - x2 * sin_v;
  float out2 = x1 * sin_v + x2 * cos_v;
  out[idx * 2] = out1;
  out[idx * 2 + 1] = out2;
}
import torch

def naive_rope(
    x: torch.Tensor,
    theta: float = 10000.0,
) -> torch.Tensor:
    dim = x.shape[-1]
    seq_len = x.shape[-2]
    # get the shape of x (ignore the head dimension).
    # x: [batch_size, seq_len, dim]
    x_ = x.float().reshape(*x.shape[:-1], -1, 2)
    # x_: [batch_size, seq_len, dim//2, 2]
    x_ = torch.view_as_complex(x_)
    # pack neighboring elements into one complex number
    # x_: [batch_size, seq_len, dim//2] (complex), e.g. tensor([(1.6116-0.5772j), ...])
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs).float().cuda()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    # rotation angles as unit complex numbers
    xq_out = torch.view_as_real(x_ * freqs_cis).flatten(1)
    # apply the rotation
    return xq_out.type_as(x)
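
To make the mismatch concrete, here is a small standalone check (my own sketch, not code from this repo). It compares the angle each version effectively uses and also explains the observation above: only the first pair of values in each row matches, because at feature index 0 the inverse frequency is 1, where dividing and multiplying give the same angle.

import torch

# Hedged sketch: theta and dim follow the Python reference above; "angle_cuda"
# mirrors the kernel's sinf(token_pos / exp_v), "angle_ref" the torch version.
theta = 10000.0
dim = 8                                      # toy head dim
k = torch.arange(0, dim, 2).float()          # feature indices 0, 2, 4, ...
inv_freq = 1.0 / theta ** (k / dim)          # inverse frequencies; inv_freq[0] == 1

for pos in [0, 1, 2]:
    angle_ref = pos * inv_freq               # Python reference: token_pos * inv_freq
    angle_cuda = pos / inv_freq              # CUDA kernel:      token_pos / inv_freq
    print(pos, torch.isclose(angle_ref, angle_cuda).tolist())
# pos = 0 -> all True (cos = 1, sin = 0 everywhere)
# pos > 0 -> only the first entry is True, i.e. only the first (x1, x2) pair agrees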

PS: the torch version also differs from the RotaryEmbedding implementation in transformers; code and results are below.

----------------------------------------------------------------------------------------------------
                                        M=8192, N=1024
----------------------------------------------------------------------------------------------------
             out_f32: ['-1.08793437 ', '-1.92047811 ', '-2.2731328  '], time:0.018716ms
tensor([[-1.0879, -1.9205, -2.2731,  ...,  0.6204,  0.9124, -1.4198],
        [-1.0231, -0.3712, -1.1853,  ..., -1.8487,  0.1226, -1.0530],
        [-0.5046, -1.3417,  1.2834,  ...,  1.0105,  0.0417,  1.0380],
        ...,
        [-0.0632, -0.1381,  0.5857,  ..., -0.1382, -0.8494, -1.7711],
        [ 0.5749,  0.5440, -0.3250,  ...,  2.1607, -1.0835, -1.0153],
        [-0.3775,  0.5131,  0.6802,  ...,  1.6858, -0.7181, -0.4846]],
       device='cuda:0')
      out_f32x4_pack: ['-1.08793437 ', '-1.92047811 ', '-2.2731328  '], time:0.018263ms
tensor([[-1.0879, -1.9205, -2.2731,  ...,  0.6204,  0.9124, -1.4198],
        [-1.0231, -0.3712, -1.1853,  ..., -1.8487,  0.1226, -1.0530],
        [-0.5046, -1.3417,  1.2834,  ...,  1.0105,  0.0417,  1.0380],
        ...,
        [-0.0632, -0.1381,  0.5857,  ..., -0.1382, -0.8494, -1.7711],
        [ 0.5749,  0.5440, -0.3250,  ...,  2.1607, -1.0835, -1.0153],
        [-0.3775,  0.5131,  0.6802,  ...,  1.6858, -0.7181, -0.4846]],
       device='cuda:0')
          out_f32_th: ['-1.08793437 ', '-1.92047811 ', '-2.2731328  '], time:1.037562ms
tensor([[-1.0879, -1.9205, -2.2731,  ...,  0.6204,  0.9124, -1.4198],
        [-1.0231, -0.3712, -1.1902,  ..., -0.9135, -0.8198, -0.6722],
        [-0.5046, -1.3417,  1.3063,  ..., -0.3342,  0.9266, -0.4697],
        ...,
        [-0.0631, -0.1382,  0.6218,  ..., -0.2877, -1.9617,  0.1009],
        [ 0.5746,  0.5443, -0.3522,  ..., -1.3872, -0.2200,  1.4685],
        [-0.3778,  0.5129,  1.5211,  ..., -1.6983,  0.7339,  0.4603]],
       device='cuda:0')
llama-rotary: 
tensor([[[[-1.0879, -1.9205, -2.2731,  ...,  0.6204,  0.9124, -1.4198],
          [ 0.6609,  0.8694, -0.8100,  ..., -0.9133, -0.8200, -0.6723],
          [ 0.3740,  0.5371,  0.0529,  ..., -0.3344,  0.9264, -0.4701],
          ...,
          [-0.1139,  1.0056, -1.3266,  ...,  1.0161, -1.3343,  1.5228],
          [ 0.5761,  2.1631,  0.2380,  ..., -0.6242, -0.0753,  0.9082],
          [ 0.2910,  1.1974,  1.6174,  ..., -0.9288, -0.8432, -0.7583]]]],
       device='cuda:0')

import torch
from torch import nn


class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, base, dim, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    return q_embed

...

rotary = LlamaRotaryEmbedding(10000.0, N)
position_ids = torch.arange(0, M, dtype=torch.int32, device=x.device).reshape((1, -1))
cos, sin = rotary(x, position_ids)
pos_emb = apply_rotary_pos_emb(x, cos, sin)
print(pos_emb)
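
The mismatch here is the pairing layout rather than the frequencies: naive_rope rotates adjacent pairs (x[2i], x[2i+1]) through the complex view, while the transformers implementation rotates (x[i], x[i + dim/2]) pairs through rotate_half. A rough standalone check (my own sketch, not from the repo; it assumes the naive_rope and LlamaRotaryEmbedding definitions above and a CUDA device, since naive_rope moves its frequencies to .cuda()):

M, N = 8, 16                                           # toy seq_len / head_dim
x = torch.randn(M, N, device="cuda")

out_complex = naive_rope(x)                            # interleaved pairs (x[2i], x[2i+1])

rotary = LlamaRotaryEmbedding(10000.0, N)
position_ids = torch.arange(M, device=x.device).reshape(1, -1)
cos, sin = rotary(x, position_ids)
out_half = apply_rotary_pos_emb(x[None, None], cos, sin)[0, 0]   # pairs (x[i], x[i + N//2])

# Feed the rotate_half version a permuted input (even feature dims first, then odd)
# and undo the permutation on its output: it then reproduces the interleaved result.
perm = torch.cat([torch.arange(0, N, 2), torch.arange(1, N, 2)]).to(x.device)
out_half_perm = apply_rotary_pos_emb(x[:, perm][None, None], cos, sin)[0, 0][:, torch.argsort(perm)]

print(torch.allclose(out_complex, out_half))                     # False in general
print(torch.allclose(out_complex, out_half_perm, atol=1e-5))     # True: same rotation, permuted layout

So the two conventions appear to be equivalent up to a fixed permutation of the head dimension, but their elementwise outputs differ, which is what the dumps above show.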
