diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp
index 93e87430..370b163c 100644
--- a/csrc/extensions.cpp
+++ b/csrc/extensions.cpp
@@ -197,21 +197,6 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
             p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch);
 }
 
-// For lightllm, rms_norm reuses the diopi implementation of internlm
-auto extRmsNormLightllm(const at::Tensor& x, const at::Tensor& weight,
-                        float eps) {
-  at::ScalarType acc_type = x.scalar_type();
-  if (x.scalar_type() == at::kBFloat16 || x.scalar_type() == at::kHalf) {
-    acc_type = at::kFloat;
-  }
-  auto inv_rms = at::empty_like(x, acc_type);
-  auto out = at::empty_like(x);
-  auto bias = at::empty_like(weight);
-  at::OptionalIntArrayRef normalized_shape = weight.sizes();
-  callDiopi(diopiRMSNorm, out, inv_rms, x, normalized_shape, weight, bias, eps);
-  return out;
-}
-
 // For lightllm, rotary_embedding reuses the diopi implementation of internlm
 void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) {
   auto seq_len = q.size(0);
@@ -229,9 +214,6 @@ void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) {
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   if (&diopiRMSNorm != nullptr) {  // Check if weak symbol defined
     m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm");
-    m.def("rms_norm_lightllm", &extRmsNormLightllm,
-          "deeplink ext_rms_norm for lightllm", py::arg("x"), py::arg("weight"),
-          py::arg("eps"));
   }
   if (&diopiRMSNormBackward != nullptr) {
     m.def("rms_norm_backward", &extRmsNormBackward,
diff --git a/csrc/pybind_type_cast.h b/csrc/pybind_type_cast.h
index 6d128981..5a645d85 100644
--- a/csrc/pybind_type_cast.h
+++ b/csrc/pybind_type_cast.h
@@ -21,28 +21,5 @@ using OptionalIntArray = c10::optional<IntArray>;
 
 }  // namespace dipu::dipu_ext
 
-namespace pybind11::detail {
-
-namespace py = pybind11;
-
-template <>
-struct type_caster<dipu::dipu_ext::OptionalIntArray> {
- public:
-  PYBIND11_TYPE_CASTER(dipu::dipu_ext::OptionalIntArray, _("OptionalIntArray"));
-
-  bool load(py::handle src, bool /*unused*/) {
-    if (PyList_Check(src.ptr())) {
-      value = py::cast<dipu::dipu_ext::IntArray>(src);
-      return true;
-    }
-    if (src.is_none()) {
-      value = c10::nullopt;
-      return true;
-    }
-    return false;
-  }
-};
-
-}  // namespace pybind11::detail
 
 #endif /* end of include guard: PYBIND_TYPE_CAST_H_PXMGELYW */
diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py
index d42208bc..e6841775 100644
--- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py
+++ b/deeplink_ext/internlm_ops/rms_norm/deeplink.py
@@ -1,64 +1,110 @@
 # Copyright (c) 2024, DeepLink.
 
 import torch
-import deeplink_ext.cpp_extensions as ext
+import deeplink_ext.cpp_extensions as cpp_ext
 
-assert hasattr(ext, "rms_norm")
+assert hasattr(cpp_ext, "rms_norm")
 
 
-# 定义自定义的 autograd 函数
-class _DeepLinkRMSNormFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, hidden_states, weight, bias, eps):
-        output = torch.empty_like(hidden_states)
-        inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
-        inv_rms = torch.empty(
-            inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device
-        )
-        ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, bias, eps)
+def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps):
+    if None == normalized_shape:
+        cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps)
+    else:
+        cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps)
 
-        ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
-        return output
 
-    @staticmethod
-    def backward(ctx, grad_output):
-        hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
-        eps = eps_tensor.item()
+def rms_norm(input, normalized_shape, weight, bias, eps):
+    output = torch.empty_like(input)
+    inv_rms_shape = list(input.shape[:-1]) + [1]
+    inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device)
+    rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps)
+
+    return [output, inv_rms]
 
-        grad_input = torch.empty_like(hidden_states)
-        grad_weight = torch.empty_like(weight)
-        grad_bias = torch.empty_like(bias)
-        ext.rms_norm_backward(
+
+def rms_norm_backward_out(
+    grad_input,
+    grad_weight,
+    grad_bias,
+    grad_output,
+    input,
+    weight,
+    bias,
+    inv_rms,
+    normalized_shape,
+    eps,
+):
+    if None == normalized_shape:
+        cpp_ext.rms_norm_backward(
             grad_input,
             grad_weight,
             grad_bias,
             grad_output,
-            hidden_states,
+            input,
             weight,
             bias,
             inv_rms,
             weight.shape,
             eps,
         )
+    else:
+        cpp_ext.rms_norm_backward(
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_output,
+            input,
+            weight,
+            bias,
+            inv_rms,
+            normalized_shape,
+            eps,
+        )
+
+
+def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps):
+    grad_input = torch.empty_like(input)
+    grad_weight = torch.empty_like(weight)
+    grad_bias = torch.empty_like(bias)
+    rms_norm_backward_out(
+        grad_input,
+        grad_weight,
+        grad_bias,
+        grad_output,
+        input,
+        weight,
+        bias,
+        inv_rms,
+        normalized_shape,
+        eps,
+    )
+
+    return [grad_input, grad_weight, grad_bias]
+
+
+# 定义自定义的 autograd 函数
+class _DeepLinkRMSNormFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, hidden_states, weight, bias, eps):
+        output, inv_rms = rms_norm(hidden_states, None, weight, bias, eps)
+        ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
+        eps = eps_tensor.item()
+        grad_input, grad_weight, grad_bias = rms_norm_backward(
+            hidden_states, grad_output, inv_rms, None, weight, bias, eps
+        )
         return grad_input, grad_weight, grad_bias, None
 
 
 class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function):
     @staticmethod
     def forward(ctx, hidden_states, weight, bias, eps, normalized_shape):
-        output = torch.empty_like(hidden_states, dtype=torch.float32)
-        inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
-        inv_rms = torch.empty(
-            inv_rms_shape, dtype=torch.float32, device=hidden_states.device
-        )
-        ext.rms_norm(
-            output,
-            inv_rms,
-            hidden_states.float(),
-            normalized_shape,
-            weight.float(),
-            bias.float(),
-            eps,
+        output, inv_rms = rms_norm(
+            hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps
         )
         output = output.half()
         inv_rms = inv_rms.half()
@@ -72,24 +118,15 @@ def backward(ctx, grad_output):
         eps = eps_tensor.item()
         normalized_shape = ctx.intermediate_results
 
-        grad_input = torch.empty_like(hidden_states, dtype=torch.float32)
-        grad_weight = torch.empty_like(weight, dtype=torch.float32)
-        grad_bias = torch.empty_like(bias, dtype=torch.float32)
-        ext.rms_norm_backward(
-            grad_input,
-            grad_weight,
-            grad_bias,
-            grad_output.float(),
+        grad_input, grad_weight, grad_bias = rms_norm_backward(
             hidden_states.float(),
-            weight.float(),
-            bias.float(),
+            grad_output.float(),
             inv_rms.float(),
             normalized_shape,
+            weight.float(),
+            bias.float(),
             eps,
         )
-        grad_output = grad_output.half()
-        hidden_states = hidden_states.half()
-        inv_rms = inv_rms.half()
         return grad_input, grad_weight, grad_bias, None, None
 
 
diff --git a/deeplink_ext/patch_lightllm.py b/deeplink_ext/patch_lightllm.py
index d371870a..4aeee369 100644
--- a/deeplink_ext/patch_lightllm.py
+++ b/deeplink_ext/patch_lightllm.py
@@ -51,7 +51,18 @@ def patch_token_softmax_reducev_inference():
         )
 
     def patch_rms_norm_lightllm():
-        rms_norm_pack.rmsnorm_forward = ext.rms_norm_lightllm
+        import torch
+
+        def rms_norm_lightllm(x, weight, eps):
+            output = torch.empty_like(x)
+            inv_rms_dtype = torch.float16 if x.dtype == torch.bfloat16 else x.dtype
+            inv_rms = torch.empty_like(x, dtype=inv_rms_dtype)
+            bias = torch.empty_like(weight)
+            ext.rms_norm(output, inv_rms, x, weight.shape, weight, bias, eps)
+
+            return output
+
+        rms_norm_pack.rmsnorm_forward = rms_norm_lightllm
 
     def patch_rotary_emb():
         rotary_emb_pack.rotary_emb_fwd = ext.rotary_emb
diff --git a/tests/test_rms_lightlm.py b/tests/test_rms_lightlm.py
index 0e6b911c..ba57369b 100644
--- a/tests/test_rms_lightlm.py
+++ b/tests/test_rms_lightlm.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2023, DeepLink.
 
 import torch
-import deeplink_ext.cpp_extensions as ext
+from deeplink_ext.internlm_ops.rms_norm.deeplink import rms_norm, rms_norm_backward
 
 # 定义输入张量
 input = torch.randn(5, 5, requires_grad=True).cuda()
@@ -17,26 +17,10 @@ normalized_shape = torch.tensor([5, 5], dtype=torch.long).cuda()
 
 print(input.is_dipu)
 
-output = torch.empty_like(input)
-inv_rms_shape = list(input.shape[:-1]) + [1]
-inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=input.device)
-ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, 1e-6)
-
-# 使用 RMS normalization 反向传播
-grad_input = torch.empty_like(grad_output)
-grad_weight = torch.empty_like(weight)
-grad_bias = torch.empty_like(bias)
-ext.rms_norm_backward(
-    grad_input,
-    grad_weight,
-    grad_bias,
-    grad_output,
-    input,
-    weight,
-    bias,
-    inv_rms,
-    weight.shape,
-    1e-6,
+output, inv_rms = rms_norm(input, None, weight, bias, 1e-6)
+
+grad_input, grad_weight, grad_bias = rms_norm_backward(
+    input, grad_output, inv_rms, None, weight, bias, 1e-6
 )
 
 print("Output:", output)
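
Note: a minimal usage sketch of the refactored Python-level wrappers introduced by this change, mirroring the updated test above. Tensor shapes, the 1e-6 eps, and the CUDA/DIPU device setup are illustrative only; passing None as normalized_shape makes rms_norm_out fall back to weight.shape.

    import torch
    from deeplink_ext.internlm_ops.rms_norm.deeplink import rms_norm, rms_norm_backward

    x = torch.randn(5, 5, requires_grad=True).cuda()
    weight = torch.randn(5, requires_grad=True).cuda()
    bias = torch.randn(5, requires_grad=True).cuda()
    grad_output = torch.ones(5, 5).cuda()

    # Forward: returns [output, inv_rms]; None defers to weight.shape inside rms_norm_out.
    output, inv_rms = rms_norm(x, None, weight, bias, 1e-6)

    # Backward: returns [grad_input, grad_weight, grad_bias].
    grad_input, grad_weight, grad_bias = rms_norm_backward(
        x, grad_output, inv_rms, None, weight, bias, 1e-6
    )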