feat: reimpl rotary embedding for npu #127

Merged: 3 commits, Sep 12, 2024
deeplink_ext/ascend_speed/rotary_embedding.py (2 additions, 1 deletion)
@@ -4,7 +4,8 @@

 platform_type = deeplink_ext_get_platform_type()
 if platform_type == PlatformType.TORCH_NPU:
-    from ._rotary_embedding_npu import RotaryEmbedding
+    # from ._rotary_embedding_npu import RotaryEmbedding
+    from .rotary_embedding_fallback import RotaryEmbeddingTorch as RotaryEmbedding
 elif platform_type == PlatformType.TORCH_DIPU:
     from ._rotary_embedding_dipu import RotaryEmbedding
 else:
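This dispatch change, like the matching edits under internevo_ops and interntrain_ops below, only rebinds the exported name, so import sites elsewhere are unchanged. A minimal caller-side sketch; the module path and class names come from the diff, the usage itself is assumed:

```python
# Caller-side sketch: on PlatformType.TORCH_NPU the name now resolves to the
# pure-PyTorch RotaryEmbeddingTorch fallback, on TORCH_DIPU it is still the
# DIPU kernel. Existing imports keep working either way.
from deeplink_ext.ascend_speed.rotary_embedding import RotaryEmbedding

print(RotaryEmbedding)  # on NPU: the fallback class, not the commented-out kernel
```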
deeplink_ext/internevo_ops/_rotary_embedding_npu.py (18 additions, 54 deletions)
@@ -7,35 +7,6 @@
 __all__ = ["ApplyRotaryEmb"]
 
 
-def _unsqueeze_to_4d(x: torch.Tensor):
-    while x.dim() < 4:
-        x = x.unsqueeze(0)
-    return x
-
-
-def _apply_rotary(x: torch.Tensor, cos, sin, confj, interleaved):
-    assert interleaved == False, "interleaved not support by torch_npu"
-
-    x_view = _unsqueeze_to_4d(x)
-    cos_view = _unsqueeze_to_4d(cos)
-    sin_view = _unsqueeze_to_4d(sin)
-
-    cos_cat = torch.cat([cos_view, cos_view], -1)
-    sin_cat = torch.cat([sin_view, sin_view], -1)
-
-    if confj:
-        sin_cat.neg_()
-
-    x_view_chunks = x_view.chunk(2, -1)
-    x_view_new = torch.cat([-x_view_chunks[1], x_view_chunks[0]], -1)
-
-    cos_x = torch.mul(cos_cat, x_view)
-    sin_x = torch.mul(sin_cat, x_view_new)
-    out = cos_x + sin_x
-
-    return out
-
-
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35
 class ApplyRotaryEmb(torch.autograd.Function):
     """
@@ -67,45 +38,38 @@ def forward(
         assert seqlen <= rotary_seqlen
         assert sin.shape == (rotary_seqlen, rotary_dim // 2)
 
-        out = _apply_rotary(
-            x[..., :rotary_dim],
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            False,
-            interleaved,
-        )
+        re_cos = rearrange(cos[:seqlen], "s d -> s 1 d")
+        re_sin = rearrange(sin[:seqlen], "s d -> s 1 d")
+
+        cat_cos = torch.cat([re_cos, re_cos], -1)
+        cat_sin = torch.cat([re_sin, re_sin], -1)
 
-        ctx.save_for_backward(cos, sin)
+        rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin)
+        ctx.save_for_backward(cat_cos, cat_sin)
         ctx.interleaved = interleaved
         ctx.in_place = in_place
 
         if in_place:
-            x[..., :rotary_dim].copy_(out[..., :rotary_dim])
+            x[..., :rotary_dim].copy_(rot)
             return x
         else:
-            if rotary_dim < head_dim:
+            out = x.detach().clone()
+            if rotary_dim < head_dim and not in_place:
                 out[..., rotary_dim:].copy_(x[..., rotary_dim:])
             return out
 
     @staticmethod
     def backward(ctx, do):
-        cos, sin = ctx.saved_tensors
+        cat_cos, cat_sin = ctx.saved_tensors
         *_, seqlen, _, head_dim = do.shape
-        rotary_dim = cos.shape[-1]
-        rotary_dim *= 2
+        rotary_dim = cat_cos.shape[-1]
 
-        out = _apply_rotary(
-            do[..., :rotary_dim],
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            True,
-            ctx.interleaved,
+        dx_out = torch_npu.npu_rotary_mul(
+            do[..., :rotary_dim], cat_cos, torch.neg(cat_sin)
         )
 
         if ctx.in_place:
-            do[..., :rotary_dim].copy_(out[..., :rotary_dim])
+            do[..., :rotary_dim].copy_(dx_out)
             return do, None, None, None, None
         else:
-            if rotary_dim < head_dim:
-                out[..., rotary_dim:].copy(do[..., rotary_dim:])
-            return out, None, None, None, None
+            dx = do.detach().clone()
+            dx[..., :rotary_dim].copy_(dx_out)
+            return dx, None, None, None, None
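The removed `_apply_rotary` helper and the new `torch_npu.npu_rotary_mul` call are intended to compute the same thing: non-interleaved (rotate-half) RoPE, with `cos`/`sin` duplicated along the last dimension so they cover the full `rotary_dim`. A pure-PyTorch sketch of that contract, useful as a CPU reference when checking the NPU kernel; the kernel's exact semantics are assumed from how the code above calls it:

```python
import torch


def rotary_mul_reference(x: torch.Tensor, cat_cos: torch.Tensor, cat_sin: torch.Tensor) -> torch.Tensor:
    """Reference for what npu_rotary_mul is expected to compute here.

    cat_cos / cat_sin are already duplicated along the last dim, i.e.
    torch.cat([cos, cos], -1) and torch.cat([sin, sin], -1), exactly as
    the forward pass above builds them.
    """
    x1, x2 = x.chunk(2, dim=-1)
    rotate_half = torch.cat([-x2, x1], dim=-1)
    return x * cat_cos + rotate_half * cat_sin


# The backward pass reuses the same formula with torch.neg(cat_sin):
# flipping the sign of sin applies the inverse rotation, which is the
# gradient of a pure rotation.
```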
deeplink_ext/internevo_ops/rotary_embedding.py (2 additions, 1 deletion)
@@ -4,7 +4,8 @@

 platform_type = deeplink_ext_get_platform_type()
 if platform_type == PlatformType.TORCH_NPU:
-    from ._rotary_embedding_npu import ApplyRotaryEmb
+    # from ._rotary_embedding_npu import ApplyRotaryEmb
+    from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
 elif platform_type == PlatformType.TORCH_DIPU:
     from ._rotary_embedding_dipu import ApplyRotaryEmb
 else:
deeplink_ext/interntrain_ops/_rotary_embedding_npu.py (41 additions, 86 deletions)
@@ -1,5 +1,4 @@
 # Copyright (c) 2024, DeepLink.
-# Copyright (c) 2024, InternEvo.
 
 import torch
 import torch_npu
@@ -8,38 +7,16 @@
 __all__ = ["ApplyRotaryEmb", "ApplyRotaryEmbQKV_"]
 
 
-def _unsqueeze_to_4d(x: torch.Tensor):
-    while x.dim() < 4:
-        x = x.unsqueeze(0)
-    return x
-
-
-def _apply_rotary(x: torch.Tensor, cos, sin, confj, interleaved):
-    assert interleaved == False, "interleaved not support by torch_npu"
-
-    x_view = _unsqueeze_to_4d(x)
-    cos_view = _unsqueeze_to_4d(cos)
-    sin_view = _unsqueeze_to_4d(sin)
-
-    cos_cat = torch.cat([cos_view, cos_view], -1)
-    sin_cat = torch.cat([sin_view, sin_view], -1)
-
-    if confj:
-        sin_cat.neg_()
-
-    x_view_chunks = x_view.chunk(2, -1)
-    x_view_new = torch.cat([-x_view_chunks[1], x_view_chunks[0]], -1)
-
-    cos_x = torch.mul(cos_cat, x_view)
-    sin_x = torch.mul(sin_cat, x_view_new)
-    out = cos_x + sin_x
-
-    return out
-
-
 class ApplyRotaryEmb(torch.autograd.Function):
     """
-    ApplyRotaryEmb
+    Apply rotary positional embedding to input tensor x.
+    Args:
+        x (Tensor): Input tensor x is of shape [seq_length, ... , dim]
+        cos (Tensor): Input tensor cos is of shape [seq_length, ..., dim]
+        sin (Tensor): Input tensor sin is of shape [seq_length, ..., dim]
+
+    Returns:
+        Tensor: The input tensor after applying RoPE
     """
 
     @staticmethod
@@ -59,34 +36,34 @@ def forward(ctx, x, cos, sin, interleaved=False):
         assert seqlen <= rotary_seqlen
         assert sin.shape == (rotary_seqlen, rotary_dim // 2)
+        out = torch.empty_like(x)
+
         re_cos = rearrange(cos[:seqlen], "s d -> s 1 d")
         re_sin = rearrange(sin[:seqlen], "s d -> s 1 d")
-        out = _apply_rotary(
-            x[..., :rotary_dim],
-            re_cos,
-            re_sin,
-            False,
-            interleaved,
-        )
 
+        cat_cos = torch.cat([re_cos, re_cos], -1)
+        cat_sin = torch.cat([re_sin, re_sin], -1)
+
+        rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin)
+        out[..., :rotary_dim].copy_(rot)
         if rotary_dim < headdim:
             out[..., rotary_dim:].copy_(x[..., rotary_dim:])
-        ctx.save_for_backward(re_cos, re_sin)
 
+        ctx.save_for_backward(cat_cos, cat_sin)
         ctx.interleaved = interleaved
        return out
 
     @staticmethod
     def backward(ctx, do):
-        re_cos, re_sin = ctx.saved_tensors
+        cat_cos, cat_sin = ctx.saved_tensors
         headdim = do.shape[-1]
-        rotary_dim = re_cos.shape[-1]
-        rotary_dim *= 2
-        dx = _apply_rotary(
-            do[..., :rotary_dim],
-            re_cos,
-            re_sin,
-            True,
-            ctx.interleaved,
+        rotary_dim = cat_cos.shape[-1]
+
+        dx = torch.empty_like(do)
+        dx_rot = torch_npu.npu_rotary_mul(
+            do[..., :rotary_dim], cat_cos, torch.neg(cat_sin)
         )
+        dx.copy_(dx_rot)
 
         if rotary_dim < headdim:
             dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
         return dx, None, None, None
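A usage sketch for the reworked `ApplyRotaryEmb` above. The `(seqlen, rotary_dim // 2)` shapes for `cos`/`sin` follow the asserts in `forward()`; the 4-D `(batch, seqlen, nheads, headdim)` input layout, the `"npu"` device string, and `rotary_dim == headdim` are assumptions made for illustration:

```python
import torch
import torch_npu  # noqa: F401  # provides the npu_rotary_mul kernel and the "npu" device
from deeplink_ext.interntrain_ops._rotary_embedding_npu import ApplyRotaryEmb

batch, seqlen, nheads, headdim = 2, 128, 8, 64
rotary_dim = headdim  # assumed equal so the full head dim is rotated

x = torch.randn(batch, seqlen, nheads, headdim, device="npu", requires_grad=True)
cos = torch.randn(seqlen, rotary_dim // 2, device="npu")
sin = torch.randn(seqlen, rotary_dim // 2, device="npu")

out = ApplyRotaryEmb.apply(x, cos, sin, False)  # interleaved=False is the only supported mode
out.sum().backward()  # backward re-applies the rotation with sin negated
```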
@@ -141,16 +118,10 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
             if len(qkv.shape) == 4
             else rearrange(sin[:seqlen], "s d -> s 1 d")
         )
-
-        # qro
-        out = _apply_rotary(
-            q_ro,
-            re_cos,
-            re_sin,
-            False,
-            interleaved,
-        )
-        q_ro.copy_(out)
+        cat_cos = torch.cat([re_cos, re_cos], -1)
+        cat_sin = torch.cat([re_sin, re_sin], -1)
+        q_out = torch_npu.npu_rotary_mul(q_ro, cat_cos, cat_sin)
+        q_ro.copy_(q_out)
 
         k_ro = (
             qkv[:, 1, :, :rotary_dim]
@@ -167,50 +138,34 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
             if len(qkv.shape) == 4
             else rearrange(sin_k[:seqlen], "s d -> s 1 d")
         )
-        out = _apply_rotary(
-            k_ro,
-            re_cos_k,
-            re_sin_k,
-            False,
-            interleaved,
-        )
-        k_ro.copy_(out)
+        cat_cos_k = torch.cat([re_cos_k, re_cos_k], -1)
+        cat_sin_k = torch.cat([re_sin_k, re_sin_k], -1)
+        k_out = torch_npu.npu_rotary_mul(k_ro, cat_cos_k, cat_sin_k)
+        k_ro.copy_(k_out)
 
-        ctx.save_for_backward(re_cos, re_sin, re_cos_k, re_sin_k)
+        ctx.save_for_backward(cat_cos, cat_sin, cat_cos_k, cat_sin_k)
         ctx.interleaved = interleaved
         return qkv
 
     @staticmethod
     def backward(ctx, dqkv):
-        re_cos, re_sin, re_cos_k, re_sin_k = ctx.saved_tensors
-        rotary_dim = re_cos.shape[-1]
-        rotary_dim *= 2
+        cat_cos, cat_sin, cat_cos_k, cat_sin_k = ctx.saved_tensors
+        rotary_dim = cat_cos.shape[-1]
 
         dq_ro = (
             dqkv[:, 0, :, :rotary_dim]
             if len(dqkv.shape) == 4
             else dqkv[:, :, 0, :, :rotary_dim]
         )
-        out = _apply_rotary(
-            dq_ro,
-            re_cos,
-            re_sin,
-            True,
-            ctx.interleaved,
-        )
-        dq_ro.copy_(out)
+        dq_out = torch_npu.npu_rotary_mul(dq_ro, cat_cos, torch.neg(cat_sin))
+        dq_ro.copy_(dq_out)
 
         dk_ro = (
             dqkv[:, 1, :, :rotary_dim]
             if len(dqkv.shape) == 4
             else dqkv[:, :, 1, :, :rotary_dim]
         )
-        out = _apply_rotary(
-            dk_ro,
-            re_cos_k,
-            re_sin_k,
-            True,
-            ctx.interleaved,
-        )
-        dk_ro.copy_(out)
+        dk_out = torch_npu.npu_rotary_mul(dk_ro, cat_cos_k, torch.neg(cat_sin_k))
+        dk_ro.copy_(dk_out)
 
         return dqkv, None, None, None, None, None
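A matching sketch for the in-place `ApplyRotaryEmbQKV_` variant. The packed `(batch, seqlen, 3, nheads, headdim)` layout is inferred from the `qkv[:, :, 0]` / `qkv[:, :, 1]` indexing in the 5-D branch, and passing `None` for `cos_k`/`sin_k` is assumed to fall back to `cos`/`sin` as in the upstream flash-attention API:

```python
import torch
import torch_npu  # noqa: F401
from deeplink_ext.interntrain_ops._rotary_embedding_npu import ApplyRotaryEmbQKV_

batch, seqlen, nheads, headdim = 2, 128, 8, 64
rotary_dim = headdim  # assumed equal to headdim for illustration

qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="npu")
cos = torch.randn(seqlen, rotary_dim // 2, device="npu")
sin = torch.randn(seqlen, rotary_dim // 2, device="npu")

# Rotates q (index 0) and k (index 1) in place; v (index 2) is left untouched.
qkv = ApplyRotaryEmbQKV_.apply(qkv, cos, sin, None, None, False)
```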
deeplink_ext/interntrain_ops/rms_norm.py (2 additions, 1 deletion)
@@ -4,7 +4,8 @@

 platform_type = deeplink_ext_get_platform_type()
 if platform_type == PlatformType.TORCH_NPU:
-    from ._mixed_rms_norm_npu import MixedFusedRMSNorm
+    # from ._mixed_rms_norm_npu import MixedFusedRMSNorm
+    from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm
 elif platform_type == PlatformType.TORCH_DIPU:
     from ._mixed_rms_norm_dipu import MixedFusedRMSNorm
 else:
deeplink_ext/interntrain_ops/rotary_embedding.py (3 additions, 1 deletion)
@@ -4,7 +4,9 @@

 platform_type = deeplink_ext_get_platform_type()
 if platform_type == PlatformType.TORCH_NPU:
-    from ._rotary_embedding_npu import ApplyRotaryEmb, ApplyRotaryEmbQKV_
+    # from ._rotary_embedding_npu import ApplyRotaryEmb, ApplyRotaryEmbQKV_
+    from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
+    from .rotary_embedding_fallback import ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_
 elif platform_type == PlatformType.TORCH_DIPU:
     from ._rotary_embedding_dipu import ApplyRotaryEmb, ApplyRotaryEmbQKV_
 else: