
Commit

refactor(lightllm): rotary emb (#60)
move RotaryEmb wrapper for lightllm from cpp to python

---------

Co-authored-by: root <[email protected]>
zhangzefeng92 and yangbofun authored Apr 1, 2024
1 parent c98cb5c commit 778fc8c
Showing 5 changed files with 64 additions and 15 deletions.
13 changes: 1 addition & 12 deletions csrc/extensions.cpp
@@ -71,7 +71,7 @@ auto extRmsNormBackward(const at::Tensor& input, const at::Tensor& grad_output,
std::move(grad_bias));
}

void extApplyRotary(at::Tensor output, const at::Tensor& input,
void extApplyRotary(at::Tensor& output, const at::Tensor& input,
const at::Tensor& cos, const at::Tensor& sin,
const bool conj, const bool interleaved) {
callDiopi(diopiRotaryEmbedding, output, input, cos, sin, conj, interleaved);
@@ -234,16 +234,6 @@ auto extRmsNormLightllm(const at::Tensor& x, const at::Tensor& weight,
return out;
}

// For lightllm, rotary_embedding reuses the diopi implementation of internlm
void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) {
auto seq_len = q.size(0);
auto dim = q.size(-1);
auto cos_view = cos.view({seq_len, 1, dim / 2});
auto sin_view = sin.view({seq_len, 1, dim / 2});
callDiopi(diopiRotaryEmbedding, q, q, cos_view, sin_view, /*conj=*/false,
/*interleaved=*/false);
}

// Check whether a corresponding diopi implementation exists:
// if it does, bind it directly via pybind;
// otherwise don't register it and handle it at the Python layer.
@@ -261,7 +251,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
}
if (&diopiRotaryEmbedding != nullptr) {
m.def("apply_rotary", &extApplyRotary, "deeplink ext_apply_rotary");
m.def("rotary_emb", &extRotaryEmb, "deeplink ext_rotary_emb for lightllm");
}
if (&diopiMultiHeadAttention != nullptr) {
m.def("mha_fwd", &extMultiHeadAttention, "deeplink ext_mha_fwd");
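Viewed from the Python side, the net effect on the extension surface is roughly the following (a sketch, assuming the extension is built with diopiRotaryEmbedding available; the assert lines are illustrative, not part of the commit):

import deeplink_ext.cpp_extensions as ext

# before this commit, both bindings were registered:
#     ext.apply_rotary(output, input, cos, sin, conj, interleaved)
#     ext.rotary_emb(q, cos, sin)   # thin C++ wrapper, removed above
# after this commit, only the generic binding remains; the lightllm-specific
# reshaping now lives in deeplink_ext/patch_lightllm.py (see below).
assert hasattr(ext, "apply_rotary")
assert not hasattr(ext, "rotary_emb")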
3 changes: 3 additions & 0 deletions deeplink_ext/ascend_speed/__init__.py
@@ -0,0 +1,3 @@
from .rotary_embedding import apply_rotary, RotaryEmbedding

__all__ = ["apply_rotary", "RotaryEmbedding"]
31 changes: 31 additions & 0 deletions deeplink_ext/ascend_speed/rotary_embedding.py
@@ -0,0 +1,31 @@
import torch
from typing import Optional, Union
import deeplink_ext.cpp_extensions as ext


def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
# note: seqlen_offsets, cu_seqlens, max_seqlen and inplace are accepted for
# interface compatibility but are not used by this DIOPI-backed path
output = torch.empty_like(x)
ext.apply_rotary(output, x, cos, sin, conjugate, interleaved)
return output


class RotaryEmbedding(torch.autograd.Function):
@staticmethod
def forward(ctx, t, cos, sin):
ctx.save_for_backward(cos, sin)
return apply_rotary(t, cos, sin)

@staticmethod
def backward(ctx, t):
cos, sin = ctx.saved_tensors
return apply_rotary(t, cos, sin, conjugate=True), None, None
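A minimal usage sketch of the new ascend_speed wrapper (assuming deeplink_ext.cpp_extensions is built and a device is available; the shapes are illustrative, loosely following the tests, and whether the underlying DIOPI kernel accepts them is device-dependent):

import torch
from deeplink_ext.ascend_speed import RotaryEmbedding

# q: [seq_len, num_heads, head_dim]; cos/sin carry half the head dimension
q = torch.randn(125, 16, 32, dtype=torch.float16, device="cuda", requires_grad=True)
cos = torch.randn(125, 16, dtype=torch.float16, device="cuda")
sin = torch.randn(125, 16, dtype=torch.float16, device="cuda")

out = RotaryEmbedding.apply(q, cos, sin)  # forward: ext.apply_rotary into a fresh tensor
out.sum().backward()                      # backward: re-applies rotary with conjugate=True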
9 changes: 8 additions & 1 deletion deeplink_ext/patch_lightllm.py
@@ -54,7 +54,14 @@ def patch_rms_norm_lightllm():
rms_norm_pack.rmsnorm_forward = ext.rms_norm_lightllm

def patch_rotary_emb():
rotary_emb_pack.rotary_emb_fwd = ext.rotary_emb
def rotary_emb(q, cos, sin):
seq_len = q.shape[0]
dim = q.shape[-1]
cos_view = cos.view([seq_len, 1, dim // 2])
sin_view = sin.view([seq_len, 1, dim // 2])
ext.apply_rotary(q, q, cos_view, sin_view, False, False)

rotary_emb_pack.rotary_emb_fwd = rotary_emb

try:
locals()[f"patch_{op}"]()
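For reference, a hedged sketch of what the patched rotary_emb_fwd does at call time (tensor names and shapes are illustrative; the point is that q is rotated in place through ext.apply_rotary after the cos/sin reshape moves to Python):

import torch
import deeplink_ext.cpp_extensions as ext

q = torch.randn(125, 16, 32, dtype=torch.float16, device="cuda")  # [seq_len, heads, head_dim]
cos = torch.randn(125, 16, dtype=torch.float16, device="cuda")    # [seq_len, head_dim // 2]
sin = torch.randn(125, 16, dtype=torch.float16, device="cuda")

seq_len, dim = q.shape[0], q.shape[-1]
cos_view = cos.view(seq_len, 1, dim // 2)  # same reshape the wrapper above performs
sin_view = sin.view(seq_len, 1, dim // 2)
ext.apply_rotary(q, q, cos_view, sin_view, False, False)  # output and input alias q: in-place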
23 changes: 21 additions & 2 deletions tests/test_rotary_emb_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import deeplink_ext.internlm_ops.rotary as ext


def RotaryEmbTest() -> bool:
def RotaryEmbTestFloat16() -> bool:
input = torch.randn(1, 125, 16, 32, dtype=torch.float16).cuda()

cos = torch.randn(217, 16, dtype=torch.float16).cuda()
@@ -18,7 +18,26 @@ def RotaryEmbTest() -> bool:
)
res2 = ext.apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace)

# there is a small numerical error on Ascend when dtype is float16
return torch.allclose(res1, res2, atol=1e-2, rtol=1e-3)


def RotaryEmbTestFloat32() -> bool:
input = torch.randn(1, 125, 16, 32, dtype=torch.float32).cuda()

cos = torch.randn(217, 16, dtype=torch.float32).cuda()
sin = torch.randn(217, 16, dtype=torch.float32).cuda()
input1 = input.detach().clone()
inplace = True
interleaved = False

res1 = ext.fallback.apply_rotary(
input, cos, sin, interleaved=interleaved, inplace=inplace
)
res2 = ext.apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace)

return torch.allclose(res1, res2)


assert RotaryEmbTest()
assert RotaryEmbTestFloat32()
assert RotaryEmbTestFloat16()
