diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp
index 8b4c6389..19cdfe29 100644
--- a/csrc/extensions.cpp
+++ b/csrc/extensions.cpp
@@ -71,7 +71,7 @@ auto extRmsNormBackward(const at::Tensor& input, const at::Tensor& grad_output,
                          std::move(grad_bias));
 }
 
-void extApplyRotary(at::Tensor output, const at::Tensor& input,
+void extApplyRotary(at::Tensor& output, const at::Tensor& input,
                     const at::Tensor& cos, const at::Tensor& sin,
                     const bool conj, const bool interleaved) {
   callDiopi(diopiRotaryEmbedding, output, input, cos, sin, conj, interleaved);
@@ -234,16 +234,6 @@ auto extRmsNormLightllm(const at::Tensor& x, const at::Tensor& weight,
   return out;
 }
 
-// For lightllm, rotary_embedding reuses the diopi implementation of internlm
-void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) {
-  auto seq_len = q.size(0);
-  auto dim = q.size(-1);
-  auto cos_view = cos.view({seq_len, 1, dim / 2});
-  auto sin_view = sin.view({seq_len, 1, dim / 2});
-  callDiopi(diopiRotaryEmbedding, q, q, cos_view, sin_view, /*conj=*/false,
-            /*interleaved=*/false);
-}
-
 // Check whether a corresponding diopi implementation exists:
 //   if so, register it via pybind directly;
 //   otherwise skip registration and let the python layer handle it.
@@ -261,7 +251,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   }
   if (&diopiRotaryEmbedding != nullptr) {
     m.def("apply_rotary", &extApplyRotary, "deeplink ext_apply_rotary");
-    m.def("rotary_emb", &extRotaryEmb, "deeplink ext_rotary_emb for lightllm");
   }
   if (&diopiMultiHeadAttention != nullptr) {
     m.def("mha_fwd", &extMultiHeadAttention, "deeplink ext_mha_fwd");
diff --git a/deeplink_ext/ascend_speed/__init__.py b/deeplink_ext/ascend_speed/__init__.py
new file mode 100644
index 00000000..399f9138
--- /dev/null
+++ b/deeplink_ext/ascend_speed/__init__.py
@@ -0,0 +1,3 @@
+from .rotary_embedding import apply_rotary, RotaryEmbedding
+
+__all__ = ["apply_rotary", "RotaryEmbedding"]
diff --git a/deeplink_ext/ascend_speed/rotary_embedding.py b/deeplink_ext/ascend_speed/rotary_embedding.py
new file mode 100644
index 00000000..63d78673
--- /dev/null
+++ b/deeplink_ext/ascend_speed/rotary_embedding.py
@@ -0,0 +1,31 @@
+import torch
+from typing import Optional, Union
+import deeplink_ext.cpp_extensions as ext
+
+
+def apply_rotary(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+    interleaved=False,
+    inplace=False,
+    conjugate=False,
+) -> torch.Tensor:
+    output = torch.empty_like(x)
+    ext.apply_rotary(output, x, cos, sin, conjugate, interleaved)
+    return output
+
+
+class RotaryEmbedding(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, t, cos, sin):
+        ctx.save_for_backward(cos, sin)
+        return apply_rotary(t, cos, sin)
+
+    @staticmethod
+    def backward(ctx, t):
+        cos, sin = ctx.saved_tensors
+        return apply_rotary(t, cos, sin, conjugate=True), None, None
diff --git a/deeplink_ext/patch_lightllm.py b/deeplink_ext/patch_lightllm.py
index d371870a..009868be 100644
--- a/deeplink_ext/patch_lightllm.py
+++ b/deeplink_ext/patch_lightllm.py
@@ -54,7 +54,14 @@ def patch_rms_norm_lightllm():
             rms_norm_pack.rmsnorm_forward = ext.rms_norm_lightllm
 
         def patch_rotary_emb():
-            rotary_emb_pack.rotary_emb_fwd = ext.rotary_emb
+            def rotary_emb(q, cos, sin):
+                seq_len = q.shape[0]
+                dim = q.shape[-1]
+                cos_view = cos.view([seq_len, 1, dim // 2])
+                sin_view = sin.view([seq_len, 1, dim // 2])
+                ext.apply_rotary(q, q, cos_view, sin_view, False, False)
+
+            rotary_emb_pack.rotary_emb_fwd = rotary_emb
 
         try:
             locals()[f"patch_{op}"]()
diff --git a/tests/test_rotary_emb_internlm.py b/tests/test_rotary_emb_internlm.py
index b84ea165..e212bd82 100644
--- a/tests/test_rotary_emb_internlm.py
+++ b/tests/test_rotary_emb_internlm.py
@@ -4,7 +4,7 @@
 import deeplink_ext.internlm_ops.rotary as ext
 
 
-def RotaryEmbTest() -> bool:
+def RotaryEmbTestFloat16() -> bool:
     input = torch.randn(1, 125, 16, 32, dtype=torch.float16).cuda()
 
     cos = torch.randn(217, 16, dtype=torch.float16).cuda()
@@ -18,7 +18,26 @@ def RotaryEmbTest() -> bool:
     )
     res2 = ext.apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace)
 
+    # there is a small numerical error on Ascend when dtype is float16
+    return torch.allclose(res1, res2, atol=1e-2, rtol=1e-3)
+
+
+def RotaryEmbTestFloat32() -> bool:
+    input = torch.randn(1, 125, 16, 32, dtype=torch.float32).cuda()
+
+    cos = torch.randn(217, 16, dtype=torch.float32).cuda()
+    sin = torch.randn(217, 16, dtype=torch.float32).cuda()
+    input1 = input.detach().clone()
+    inplace = True
+    interleaved = False
+
+    res1 = ext.fallback.apply_rotary(
+        input, cos, sin, interleaved=interleaved, inplace=inplace
+    )
+    res2 = ext.apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace)
+
     return torch.allclose(res1, res2)
 
 
-assert RotaryEmbTest()
+assert RotaryEmbTestFloat32()
+assert RotaryEmbTestFloat16()
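For reference, a minimal usage sketch of the new `deeplink_ext.ascend_speed.RotaryEmbedding` autograd function added above. This is not part of the diff, and the tensor shapes are assumptions borrowed from the float16 case in tests/test_rotary_emb_internlm.py, since this change adds no test for the ascend_speed path:

```python
import torch
from deeplink_ext.ascend_speed import RotaryEmbedding

# Hypothetical shapes, mirroring the internlm rotary test:
# t is (batch, seqlen, nheads, headdim); cos/sin are (max_seqlen, headdim // 2).
t = torch.randn(1, 125, 16, 32, dtype=torch.float16).cuda().requires_grad_()
cos = torch.randn(217, 16, dtype=torch.float16).cuda()
sin = torch.randn(217, 16, dtype=torch.float16).cuda()

out = RotaryEmbedding.apply(t, cos, sin)  # forward calls ext.apply_rotary(output, t, cos, sin, False, False)
out.sum().backward()                      # backward re-applies the rotary kernel with conjugate=True
```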