refactor: RMSNorm #59

Merged · 43 commits · Apr 1, 2024

Commits
23f7a96
update cpp ext of rms norm
zhangzefeng92 Mar 26, 2024
b0aa03d
Update test_rms_lightlm.py
zhangzefeng92 Mar 27, 2024
0517125
Update test_rms_lightlm.py
zhangzefeng92 Mar 27, 2024
d675c9f
Update extensions.cpp
zhangzefeng92 Mar 27, 2024
3f100c0
modify extensions.cpp
zhangzefeng92 Mar 27, 2024
fce081f
fix python lint
zhangzefeng92 Mar 27, 2024
7e59912
fix python lint
zhangzefeng92 Mar 27, 2024
bd74cb0
fix python lint
zhangzefeng92 Mar 27, 2024
5027589
fix python lint
zhangzefeng92 Mar 27, 2024
7fc26c7
fix rms norm
zhangzefeng92 Mar 27, 2024
136e8e1
modify rms norm
zhangzefeng92 Mar 28, 2024
9277e0a
update cpp ext of rms norm
zhangzefeng92 Mar 26, 2024
1be8963
Update test_rms_lightlm.py
zhangzefeng92 Mar 27, 2024
6cd9db6
Update test_rms_lightlm.py
zhangzefeng92 Mar 27, 2024
fd486a6
Update extensions.cpp
zhangzefeng92 Mar 27, 2024
d54413b
modify extensions.cpp
zhangzefeng92 Mar 27, 2024
6dad1ed
fix python lint
zhangzefeng92 Mar 27, 2024
0a8a28d
fix python lint
zhangzefeng92 Mar 27, 2024
4ef9d48
fix python lint
zhangzefeng92 Mar 27, 2024
241492b
fix python lint
zhangzefeng92 Mar 27, 2024
5cafafe
fix rms norm
zhangzefeng92 Mar 27, 2024
3d9a805
modify rms norm
zhangzefeng92 Mar 28, 2024
33d3618
Merge branch 'main' into zzf/fix_rmsnorm
zhangzefeng92 Mar 29, 2024
28274a0
Merge branch 'zzf/fix_rmsnorm' of https://github.com/DeepLink-org/Dee…
zhangzefeng92 Apr 1, 2024
27da2ec
modify rms norm
zhangzefeng92 Apr 1, 2024
3f894a2
modify rms norm
zhangzefeng92 Apr 1, 2024
a5590c8
Merge branch 'main' into zzf/fix_rmsnorm
yangbofun Apr 1, 2024
52bc927
lint
yangbofun Apr 1, 2024
5ac11ca
delete the duplicated
yangbofun Apr 1, 2024
4a31673
delete
yangbofun Apr 1, 2024
8ab1be0
Update __init__.py
zhangzefeng92 Apr 1, 2024
c314b13
modify test
yangbofun Apr 1, 2024
645c8da
Merge branch 'zzf/fix_rmsnorm' of https://github.com/DeepLink-org/Dee…
yangbofun Apr 1, 2024
dc674ca
modify
yangbofun Apr 1, 2024
7ce2b22
modify rotary_embeding
yangbofun Apr 1, 2024
03d5992
modify rotary_embeding
yangbofun Apr 1, 2024
92bcf46
modify rotary_embeding
yangbofun Apr 1, 2024
b0965fb
lint
yangbofun Apr 1, 2024
e979201
fix
yangbofun Apr 1, 2024
a3cd9da
fix
yangbofun Apr 1, 2024
ca0b33c
modify mha
yangbofun Apr 1, 2024
6d798aa
rename rotary_embedding
yangbofun Apr 1, 2024
3380cfb
lint
yangbofun Apr 1, 2024
43 changes: 9 additions & 34 deletions csrc/extensions.cpp
@@ -3,7 +3,6 @@
#include <cstdint>
#include <tuple>
#include <utility>
-#include <vector>

#include <ATen/core/ATen_fwd.h>
#include <ATen/core/Generator.h>
@@ -27,43 +26,22 @@

namespace dipu::dipu_ext {

-namespace {
-
-at::IntArrayRef optionalIntArrayToIntArrayRefOrDefault(
-    const OptionalIntArray& opt, at::IntArrayRef def) {
-  if (opt) {
-    return {*opt};
-  }
-  return def;
-}
-
-}  // namespace
-
-auto extRmsNorm(const at::Tensor& input,
+auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
+                const at::Tensor& input,
                const OptionalIntArray& normalized_shape,
    [Collaborator comment] This signature needs to be changed; that header file can be deleted entirely.
                const at::Tensor& weight, const at::Tensor& bias, double eps) {
-  at::OptionalIntArrayRef normalized_shape_at =
-      optionalIntArrayToIntArrayRefOrDefault(normalized_shape, weight.sizes());
-  auto input_shape = input.sizes();
-  std::vector<int64_t> input_size(input_shape.begin(), input_shape.end());
-  input_size.back() = 1;
-  auto inv_rms = at::empty(input_size, input.options());
-  auto output = at::empty_like(input);
+  at::OptionalIntArrayRef normalized_shape_at = *normalized_shape;
  callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight,
            bias, eps);
  return std::make_tuple(std::move(output), std::move(inv_rms));
}

-auto extRmsNormBackward(const at::Tensor& input, const at::Tensor& grad_output,
-                        const at::Tensor& inv_rms,
-                        const OptionalIntArray& normalized_shape,
-                        const at::Tensor& weight, const at::Tensor& bias,
-                        double eps) {
-  at::OptionalIntArrayRef normalized_shape_at =
-      optionalIntArrayToIntArrayRefOrDefault(normalized_shape, weight.sizes());
-  auto grad_input = at::empty_like(grad_output);
-  auto grad_weight = at::empty_like(weight);
-  auto grad_bias = at::empty_like(bias);
+auto extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight,
+                        at::Tensor& grad_bias, const at::Tensor& grad_output,
+                        const at::Tensor& input, const at::Tensor& weight,
+                        const at::Tensor& bias, const at::Tensor& inv_rms,
+                        const OptionalIntArray& normalized_shape, double eps) {
+  at::OptionalIntArrayRef normalized_shape_at = *normalized_shape;
  callDiopi(diopiRMSNormBackward, grad_input, grad_weight, grad_bias,
            grad_output, input, weight, bias, inv_rms, normalized_shape_at,
            eps);
@@ -241,9 +219,6 @@ auto extRmsNormLightllm(const at::Tensor& x, const at::Tensor& weight,
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  if (&diopiRMSNorm != nullptr) {  // Check if weak symbol defined
    m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm");
-    m.def("rms_norm_lightllm", &extRmsNormLightllm,
-          "deeplink ext_rms_norm for lightllm", py::arg("x"), py::arg("weight"),
-          py::arg("eps"));
  }
  if (&diopiRMSNormBackward != nullptr) {
    m.def("rms_norm_backward", &extRmsNormBackward,
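For orientation (not part of the diff), a minimal Python sketch of how the refactored out-variant binding is expected to be called: the caller now preallocates output and inv_rms, whereas the old extRmsNorm allocated them internally. Tensor shapes, the device, and the eps value are illustrative only.

import torch
import deeplink_ext.cpp_extensions as cpp_ext  # built from csrc/extensions.cpp

# Illustrative shapes: (batch, seq_len, hidden); hidden must match weight.
input = torch.randn(2, 128, 4096, device="cuda")
weight = torch.ones(4096, device="cuda")
bias = torch.zeros(4096, device="cuda")

# Caller-allocated outputs; inv_rms keeps all dims of input except the last, which collapses to 1.
output = torch.empty_like(input)
inv_rms = torch.empty(list(input.shape[:-1]) + [1], dtype=input.dtype, device=input.device)

# normalized_shape is now passed explicitly (here, weight.shape) instead of being defaulted in C++.
cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, 1e-6)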
24 changes: 0 additions & 24 deletions csrc/pybind_type_cast.h
@@ -21,28 +21,4 @@ using OptionalIntArray = c10::optional<IntArray>;

} // namespace dipu::dipu_ext

-namespace pybind11::detail {
-
-namespace py = pybind11;
-
-template <>
-struct type_caster<at::OptionalIntArrayRef> {
- public:
-  PYBIND11_TYPE_CASTER(dipu::dipu_ext::OptionalIntArray, _("OptionalIntArray"));
-
-  bool load(py::handle src, bool /*unused*/) {
-    if (PyList_Check(src.ptr())) {
-      value = py::cast<dipu::dipu_ext::IntArray>(src);
-      return true;
-    }
-    if (src.is_none()) {
-      value = c10::nullopt;
-      return true;
-    }
-    return false;
-  }
-};
-
-}  // namespace pybind11::detail

#endif /* end of include guard: PYBIND_TYPE_CAST_H_PXMGELYW */
4 changes: 4 additions & 0 deletions deeplink_ext/common/__init__.py
@@ -0,0 +1,4 @@
from .deeplink import rms_norm_out, rms_norm, rms_norm_backward_out, rms_norm_backward


__all__ = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"]
78 changes: 78 additions & 0 deletions deeplink_ext/common/rms_norm.py
@@ -0,0 +1,78 @@
import torch
import deeplink_ext.cpp_extensions as cpp_ext


def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps):
    if None == normalized_shape:
        cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps)
    else:
        cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps)


def rms_norm(input, normalized_shape, weight, bias, eps):
    output = torch.empty_like(input)
    inv_rms_shape = list(input.shape[:-1]) + [1]
    inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device)
    rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps)

    return [output, inv_rms]


def rms_norm_backward_out(
    grad_input,
    grad_weight,
    grad_bias,
    grad_output,
    input,
    weight,
    bias,
    inv_rms,
    normalized_shape,
    eps,
):
    if None == normalized_shape:
        cpp_ext.rms_norm_backward(
            grad_input,
            grad_weight,
            grad_bias,
            grad_output,
            input,
            weight,
            bias,
            inv_rms,
            weight.shape,
            eps,
        )
    else:
        cpp_ext.rms_norm_backward(
            grad_input,
            grad_weight,
            grad_bias,
            grad_output,
            input,
            weight,
            bias,
            inv_rms,
            normalized_shape,
            eps,
        )


def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps):
    grad_input = torch.empty_like(input)
    grad_weight = torch.empty_like(weight)
    grad_bias = torch.empty_like(bias)
    rms_norm_backward_out(
        grad_input,
        grad_weight,
        grad_bias,
        grad_output,
        input,
        weight,
        bias,
        inv_rms,
        normalized_shape,
        eps,
    )

    return [grad_input, grad_weight, grad_bias]
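For context, a minimal end-to-end usage sketch of the new wrappers above (shapes, device, and eps are illustrative; assumes the extension was built with diopiRMSNorm and diopiRMSNormBackward available):

import torch
from deeplink_ext.common.rms_norm import rms_norm, rms_norm_backward

input = torch.randn(2, 16, 4096, device="cuda")  # (batch, seq_len, hidden); device is illustrative
weight = torch.ones(4096, device="cuda")
bias = torch.zeros(4096, device="cuda")
eps = 1e-6

# Forward: passing None for normalized_shape falls back to weight.shape;
# inv_rms has shape (2, 16, 1), i.e. input's shape with the last dim collapsed to 1.
output, inv_rms = rms_norm(input, None, weight, bias, eps)

# Backward: gradients are preallocated and filled via the *_out variant under the hood.
grad_output = torch.ones_like(output)
grad_input, grad_weight, grad_bias = rms_norm_backward(
    input, grad_output, inv_rms, None, weight, bias, eps
)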
39 changes: 37 additions & 2 deletions deeplink_ext/internlm_ops/__init__.py
@@ -1,5 +1,40 @@
# Copyright (c) 2024, DeepLink.

-from . import mha, rms_norm, rotary
+from . import mha

-__all__ = ["mha", "rms_norm", "rotary"]
+_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."
+
+
+try:
+    from .rms_norm import RMSNorm, RMSNormWithNormalizedShape
+except:
+    print(
+        _not_impl.format(op_name="RMSNorm or RMSNormWithNormalizedShape"),
+    )
+    from .rms_norm_fallback import (
+        RMSNorm,
+        RMSNormWithNormalizedShape,
+    )
+
+
+try:
+    from .rotary_embedding import apply_rotary
+except:
+    print(_not_impl.format(op_name="apply_rotary"))
+    from .rotary_embeddinig_fallback import apply_rotary
+
+
+try:
+    from .mha import SelfAttention, CrossAttention
+except Exception as e:
+    print(_not_impl.format(op_name="mha"))
+    from .mha_fallback import SelfAttention, CrossAttention
+
+__all__ = [
+    "SelfAttention",
+    "CrossAttention",
+    "RMSNorm",
+    "RMSNormWithNormalizedShape",
+    "apply_rotary",
+]
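A short, hedged sketch of how downstream code might consume these exports. If the DIOPI kernels are missing, the imports above fall back to the torch implementations and print the _not_impl notice; the RMSNorm constructor signature is assumed here to take the hidden size (InternLM-style) and is not shown in this diff.

import torch
from deeplink_ext.internlm_ops import RMSNorm  # falls back automatically if diopi ops are absent

hidden_size = 4096
norm = RMSNorm(hidden_size)          # assumed nn.Module-style constructor; see rms_norm.py in this package
x = torch.randn(2, 16, hidden_size)
y = norm(x)                          # RMS-normalized output, same shape as x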