modify rms norm
zhangzefeng92 committed Mar 28, 2024
1 parent 7fc26c7 commit 8eca230
Showing 5 changed files with 103 additions and 112 deletions.
18 changes: 0 additions & 18 deletions csrc/extensions.cpp
@@ -197,21 +197,6 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch);
}

// For lightllm, rms_norm reuses the diopi implementation of internlm
auto extRmsNormLightllm(const at::Tensor& x, const at::Tensor& weight,
float eps) {
at::ScalarType acc_type = x.scalar_type();
if (x.scalar_type() == at::kBFloat16 || x.scalar_type() == at::kHalf) {
acc_type = at::kFloat;
}
auto inv_rms = at::empty_like(x, acc_type);
auto out = at::empty_like(x);
auto bias = at::empty_like(weight);
at::OptionalIntArrayRef normalized_shape = weight.sizes();
callDiopi(diopiRMSNorm, out, inv_rms, x, normalized_shape, weight, bias, eps);
return out;
}

// For lightllm, rotary_embedding reuses the diopi implementation of internlm
void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) {
auto seq_len = q.size(0);
@@ -229,9 +214,6 @@ void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
if (&diopiRMSNorm != nullptr) { // Check if weak symbol defined
m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm");
m.def("rms_norm_lightllm", &extRmsNormLightllm,
"deeplink ext_rms_norm for lightllm", py::arg("x"), py::arg("weight"),
py::arg("eps"));
}
if (&diopiRMSNormBackward != nullptr) {
m.def("rms_norm_backward", &extRmsNormBackward,
23 changes: 0 additions & 23 deletions csrc/pybind_type_cast.h
@@ -21,28 +21,5 @@ using OptionalIntArray = c10::optional<IntArray>;

} // namespace dipu::dipu_ext

namespace pybind11::detail {

namespace py = pybind11;

template <>
struct type_caster<at::OptionalIntArrayRef> {
public:
PYBIND11_TYPE_CASTER(dipu::dipu_ext::OptionalIntArray, _("OptionalIntArray"));

bool load(py::handle src, bool /*unused*/) {
if (PyList_Check(src.ptr())) {
value = py::cast<dipu::dipu_ext::IntArray>(src);
return true;
}
if (src.is_none()) {
value = c10::nullopt;
return true;
}
return false;
}
};

} // namespace pybind11::detail

#endif /* end of include guard: PYBIND_TYPE_CAST_H_PXMGELYW */
135 changes: 86 additions & 49 deletions deeplink_ext/internlm_ops/rms_norm/deeplink.py
@@ -1,64 +1,110 @@
# Copyright (c) 2024, DeepLink.

import torch
import deeplink_ext.cpp_extensions as ext
import deeplink_ext.cpp_extensions as cpp_ext

assert hasattr(ext, "rms_norm")
assert hasattr(cpp_ext, "rms_norm")


# Define a custom autograd function
class _DeepLinkRMSNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, hidden_states, weight, bias, eps):
output = torch.empty_like(hidden_states)
inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
inv_rms = torch.empty(
inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device
)
ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, bias, eps)
def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps):
if None == normalized_shape:
cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps)
else:
cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps)

ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
return output

@staticmethod
def backward(ctx, grad_output):
hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
eps = eps_tensor.item()
def rms_norm(input, normalized_shape, weight, bias, eps):
output = torch.empty_like(input)
inv_rms_shape = list(input.shape[:-1]) + [1]
inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device)
rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps)

return [output, inv_rms]

grad_input = torch.empty_like(hidden_states)
grad_weight = torch.empty_like(weight)
grad_bias = torch.empty_like(bias)
ext.rms_norm_backward(

def rms_norm_backward_out(
grad_input,
grad_weight,
grad_bias,
grad_output,
input,
weight,
bias,
inv_rms,
normalized_shape,
eps,
):
if None == normalized_shape:
cpp_ext.rms_norm_backward(
grad_input,
grad_weight,
grad_bias,
grad_output,
hidden_states,
input,
weight,
bias,
inv_rms,
weight.shape,
eps,
)
else:
cpp_ext.rms_norm_backward(
grad_input,
grad_weight,
grad_bias,
grad_output,
input,
weight,
bias,
inv_rms,
normalized_shape,
eps,
)


def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps):
grad_input = torch.empty_like(input)
grad_weight = torch.empty_like(weight)
grad_bias = torch.empty_like(bias)
rms_norm_backward_out(
grad_input,
grad_weight,
grad_bias,
grad_output,
input,
weight,
bias,
inv_rms,
normalized_shape,
eps,
)

return [grad_input, grad_weight, grad_bias]


# Define a custom autograd function
class _DeepLinkRMSNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, hidden_states, weight, bias, eps):
output, inv_rms = rms_norm(hidden_states, None, weight, bias, eps)
ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
return output

@staticmethod
def backward(ctx, grad_output):
hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
eps = eps_tensor.item()
grad_input, grad_weight, grad_bias = rms_norm_backward(
hidden_states, grad_output, inv_rms, None, weight, bias, eps
)
return grad_input, grad_weight, grad_bias, None


class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function):
@staticmethod
def forward(ctx, hidden_states, weight, bias, eps, normalized_shape):
output = torch.empty_like(hidden_states, dtype=torch.float32)
inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
inv_rms = torch.empty(
inv_rms_shape, dtype=torch.float32, device=hidden_states.device
)
ext.rms_norm(
output,
inv_rms,
hidden_states.float(),
normalized_shape,
weight.float(),
bias.float(),
eps,
output, inv_rms = rms_norm(
hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps
)
output = output.half()
inv_rms = inv_rms.half()
@@ -72,24 +118,15 @@ def backward(ctx, grad_output):
eps = eps_tensor.item()
normalized_shape = ctx.intermediate_results

grad_input = torch.empty_like(hidden_states, dtype=torch.float32)
grad_weight = torch.empty_like(weight, dtype=torch.float32)
grad_bias = torch.empty_like(bias, dtype=torch.float32)
ext.rms_norm_backward(
grad_input,
grad_weight,
grad_bias,
grad_output.float(),
grad_input, grad_weight, grad_bias = rms_norm_backward(
hidden_states.float(),
weight.float(),
bias.float(),
grad_output.float(),
inv_rms.float(),
normalized_shape,
weight.float(),
bias.float(),
eps,
)
grad_output = grad_output.half()
hidden_states = hidden_states.half()
inv_rms = inv_rms.half()
return grad_input, grad_weight, grad_bias, None, None


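For reference, the refactored deeplink.py exposes rms_norm and rms_norm_backward as plain module-level helpers alongside the autograd wrappers. A minimal usage sketch (assuming a DIPU-enabled build where .cuda() targets the device; shapes and eps are illustrative):

import torch
from deeplink_ext.internlm_ops.rms_norm.deeplink import rms_norm, rms_norm_backward

hidden = torch.randn(4, 16).cuda()
weight = torch.ones(16).cuda()
bias = torch.zeros(16).cuda()

# Forward: passing None for normalized_shape falls back to weight.shape.
output, inv_rms = rms_norm(hidden, None, weight, bias, 1e-6)

# Backward: reuses the inv_rms produced by the forward call.
grad_output = torch.ones_like(output)
grad_input, grad_weight, grad_bias = rms_norm_backward(
    hidden, grad_output, inv_rms, None, weight, bias, 1e-6
)

Returning inv_rms from the forward pass lets the autograd wrappers save it for backward instead of recomputing it.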
13 changes: 12 additions & 1 deletion deeplink_ext/patch_lightllm.py
@@ -51,7 +51,18 @@ def patch_token_softmax_reducev_inference():
)

def patch_rms_norm_lightllm():
rms_norm_pack.rmsnorm_forward = ext.rms_norm_lightllm
import torch

def rms_norm_lightllm(x, weight, eps):
output = torch.empty_like(x)
inv_rms_dtype = torch.float16 if x.dtype == torch.bfloat16 else x.dtype
inv_rms = torch.empty_like(x, dtype=inv_rms_dtype)
bias = torch.empty_like(weight)
ext.rms_norm(output, inv_rms, x, weight.shape, weight, bias, eps)

return output

rms_norm_pack.rmsnorm_forward = rms_norm_lightllm

def patch_rotary_emb():
rotary_emb_pack.rotary_emb_fwd = ext.rotary_emb
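The patched rmsnorm_forward above delegates to the DIOPI rms_norm kernel. For intuition only, the standard RMS normalization it is expected to compute can be sketched in plain PyTorch as below; this is a simplified reference, not the extension's implementation, and the kernel's accumulation dtype and handling of the bias argument may differ.

import torch

def rms_norm_reference(x, weight, eps):
    # y = x / sqrt(mean(x^2, dim=-1) + eps) * weight, accumulated in float32.
    inv_rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * inv_rms * weight.float()).to(x.dtype)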
26 changes: 5 additions & 21 deletions tests/test_rms_lightlm.py
@@ -1,7 +1,7 @@
# Copyright (c) 2023, DeepLink.

import torch
import deeplink_ext.cpp_extensions as ext
from deeplink_ext.internlm_ops.rms_norm.deeplink import rms_norm, rms_norm_backward

# Define the input tensors
input = torch.randn(5, 5, requires_grad=True).cuda()
@@ -17,26 +17,10 @@
normalized_shape = torch.tensor([5, 5], dtype=torch.long).cuda()

print(input.is_dipu)
output = torch.empty_like(input)
inv_rms_shape = list(input.shape[:-1]) + [1]
inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=input.device)
ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, 1e-6)

# Backward pass of RMS normalization
grad_input = torch.empty_like(grad_output)
grad_weight = torch.empty_like(weight)
grad_bias = torch.empty_like(bias)
ext.rms_norm_backward(
grad_input,
grad_weight,
grad_bias,
grad_output,
input,
weight,
bias,
inv_rms,
weight.shape,
1e-6,
output, inv_rms = rms_norm(input, None, weight, bias, 1e-6)

grad_input, grad_weight, grad_bias = rms_norm_backward(
input, grad_output, inv_rms, None, weight, bias, 1e-6
)

print("Output:", output)
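A possible extra check for this test, assuming the kernel implements the standard RMSNorm formula and that the bias argument does not change the forward result (both are assumptions, not guarantees from the source):

# Compare the extension's forward output against a plain-PyTorch RMSNorm.
ref_inv_rms = torch.rsqrt(input.float().pow(2).mean(dim=-1, keepdim=True) + 1e-6)
ref_output = input * ref_inv_rms * weight
print("forward matches reference:", torch.allclose(output, ref_output, atol=1e-3, rtol=1e-3))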
