
Commit

modify rms norm
zhangzefeng92 committed Apr 1, 2024
1 parent 28274a0 commit 27da2ec
Showing 4 changed files with 86 additions and 91 deletions.
4 changes: 4 additions & 0 deletions deeplink_ext/common/rms_norm/__init__.py
@@ -0,0 +1,4 @@
from .deeplink import rms_norm_out, rms_norm, rms_norm_backward_out, rms_norm_backward


__all__ = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"]
79 changes: 79 additions & 0 deletions deeplink_ext/common/rms_norm/deeplink.py
@@ -0,0 +1,79 @@
import torch
import deeplink_ext.cpp_extensions as cpp_ext


def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps):
    if normalized_shape is None:
        cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps)
    else:
        cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps)


def rms_norm(input, normalized_shape, weight, bias, eps):
    output = torch.empty_like(input)
    inv_rms_shape = list(input.shape[:-1]) + [1]
    inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device)
    rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps)

    return [output, inv_rms]


def rms_norm_backward_out(
    grad_input,
    grad_weight,
    grad_bias,
    grad_output,
    input,
    weight,
    bias,
    inv_rms,
    normalized_shape,
    eps,
):
    if normalized_shape is None:
        cpp_ext.rms_norm_backward(
            grad_input,
            grad_weight,
            grad_bias,
            grad_output,
            input,
            weight,
            bias,
            inv_rms,
            weight.shape,
            eps,
        )
    else:
        cpp_ext.rms_norm_backward(
            grad_input,
            grad_weight,
            grad_bias,
            grad_output,
            input,
            weight,
            bias,
            inv_rms,
            normalized_shape,
            eps,
        )


def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps):
    grad_input = torch.empty_like(input)
    grad_weight = torch.empty_like(weight)
    grad_bias = torch.empty_like(bias)
    rms_norm_backward_out(
        grad_input,
        grad_weight,
        grad_bias,
        grad_output,
        input,
        weight,
        bias,
        inv_rms,
        normalized_shape,
        eps,
    )

    return [grad_input, grad_weight, grad_bias]
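
For a quick smoke test, the new helpers can be exercised roughly as follows. This is a minimal sketch, not part of the commit: the device string and tensor shapes are illustrative, and it assumes the cpp_extensions backend is built for the target device.

import torch
from deeplink_ext.common.rms_norm import rms_norm, rms_norm_backward

device = "cuda"  # illustrative; use whatever device deeplink_ext targets
x = torch.randn(4, 64, device=device)
weight = torch.ones(64, device=device)
bias = torch.zeros(64, device=device)

# Passing None for normalized_shape falls back to weight.shape.
output, inv_rms = rms_norm(x, None, weight, bias, 1e-6)

# The backward entry point consumes the inv_rms saved by the forward pass.
grad_output = torch.ones_like(output)
grad_input, grad_weight, grad_bias = rms_norm_backward(
    x, grad_output, inv_rms, None, weight, bias, 1e-6
)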

80 changes: 1 addition & 79 deletions deeplink_ext/internlm_ops/rms_norm/deeplink.py
@@ -1,85 +1,7 @@
# Copyright (c) 2024, DeepLink.

import torch
import deeplink_ext.cpp_extensions as cpp_ext

assert hasattr(cpp_ext, "rms_norm")


def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps):
    if normalized_shape is None:
        cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps)
    else:
        cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps)


def rms_norm(input, normalized_shape, weight, bias, eps):
    output = torch.empty_like(input)
    inv_rms_shape = list(input.shape[:-1]) + [1]
    inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device)
    rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps)

    return [output, inv_rms]


def rms_norm_backward_out(
    grad_input,
    grad_weight,
    grad_bias,
    grad_output,
    input,
    weight,
    bias,
    inv_rms,
    normalized_shape,
    eps,
):
    if normalized_shape is None:
        cpp_ext.rms_norm_backward(
            grad_input,
            grad_weight,
            grad_bias,
            grad_output,
            input,
            weight,
            bias,
            inv_rms,
            weight.shape,
            eps,
        )
    else:
        cpp_ext.rms_norm_backward(
            grad_input,
            grad_weight,
            grad_bias,
            grad_output,
            input,
            weight,
            bias,
            inv_rms,
            normalized_shape,
            eps,
        )


def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps):
    grad_input = torch.empty_like(input)
    grad_weight = torch.empty_like(weight)
    grad_bias = torch.empty_like(bias)
    rms_norm_backward_out(
        grad_input,
        grad_weight,
        grad_bias,
        grad_output,
        input,
        weight,
        bias,
        inv_rms,
        normalized_shape,
        eps,
    )

    return [grad_input, grad_weight, grad_bias]
from deeplink_ext.common.rms_norm.deeplink import rms_norm, rms_norm_backward


# Define the custom autograd function
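For context, the computation these kernels implement, and which the custom autograd function wraps, is standard RMS normalization. A pure-PyTorch reference might look like the sketch below; it is an assumption inferred from the buffer shapes above, not code from this commit.

import torch

def rms_norm_reference(x, weight, bias, eps):
    # inv_rms has shape x.shape[:-1] + (1,), matching the buffer
    # allocated by rms_norm() in deeplink.py.
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * inv_rms * weight + bias, inv_rms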
14 changes: 2 additions & 12 deletions deeplink_ext/patch_lightllm.py
@@ -51,18 +51,8 @@ def patch_token_softmax_reducev_inference():
)

def patch_rms_norm_lightllm():
    import torch

    def rms_norm_lightllm(x, weight, eps):
        output = torch.empty_like(x)
        inv_rms_dtype = torch.float16 if x.dtype == torch.bfloat16 else x.dtype
        inv_rms = torch.empty_like(x, dtype=inv_rms_dtype)
        bias = torch.empty_like(weight)
        ext.rms_norm(output, inv_rms, x, weight.shape, weight, bias, eps)

        return output

    rms_norm_pack.rmsnorm_forward = rms_norm_lightllm
    from .common.rms_norm.deeplink import rms_norm
    rms_norm_pack.rmsnorm_forward = rms_norm

def patch_rotary_emb():
    rotary_emb_pack.rotary_emb_fwd = ext.rotary_emb
