Merge pull request #35 from DeepLink-org/fuj/fix-rms-norm
Fuj/support ascend rms_norm
mrdanielw authored Jan 18, 2024
2 parents ee45ff3 + 2ad5f10 commit 7ef9c24
Showing 2 changed files with 26 additions and 56 deletions.
ext_op/example_ext.cpp: 5 changes (4 additions, 1 deletion)
@@ -43,7 +43,10 @@ auto extRmsNorm(const at::Tensor& input,
                 const at::Tensor& weight, const at::Tensor& bias, double eps) {
   at::OptionalIntArrayRef normalized_shape_at =
       optionalIntArrayToIntArrayRefOrDefault(normalized_shape, weight.sizes());
-  auto inv_rms = at::empty_like(input);
+  auto input_shape = input.sizes();
+  std::vector<int64_t> input_size(input_shape.begin(), input_shape.end());
+  input_size.back() = 1;
+  auto inv_rms = at::empty(input_size, input.options());
   auto output = at::empty_like(input);
   callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight,
             bias, eps);
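
The change above allocates inv_rms with the input's shape except that the last dimension is set to 1, rather than with the full input shape: RMS normalization reduces over the last (normalized) dimension, so there is exactly one inverse-RMS statistic per normalized vector. A minimal PyTorch sketch of the standard RMSNorm forward (illustration only, not code from this commit; rms_norm_reference is a hypothetical helper) makes the shape explicit:

import torch

def rms_norm_reference(x, weight, eps=1e-6):
    # One inverse-RMS value per vector along the last dim -> shape [..., 1]
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * inv_rms * weight, inv_rms

x = torch.randn(5, 5)
out, inv_rms = rms_norm_reference(x, torch.ones(5))
print(out.shape, inv_rms.shape)  # torch.Size([5, 5]) torch.Size([5, 1])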
test/test_rms_internlm.py: 77 changes (22 additions, 55 deletions)
@@ -1,67 +1,34 @@
-from DeepLinkExt.ext_apply.internlm.RMSNorm import (
-    InternLMRMSNorm,
-    DeepLinkRMSNorm,
-    DeepLinkRMSNorm_WithNormalizedShape,
-)
 import torch
 from torch import nn
 import torch_dipu
 import numpy as np
 
+from ext_apply.internlm.RMSNorm import (
+    InternLMRMSNorm,
+    DeeplinkRMSNorm,
+    DeeplinkRMSNorm_WithNormalizedShape,
+)
 
-def test_forward_backward(Basenet, Testnet, rtol=1e-5, atol=1e-5):
-    input = torch.randn(5, 5, requires_grad=True).cuda()
-    input_dipu = input.clone()
-    hidden_size = 5
-
-    intern = Basenet(hidden_size).cuda()
-    deep = Testnet(hidden_size).cuda()
-    y_intern = intern(input)
-    y_dipu = deep(input_dipu)
-
-    y_label = torch.ones_like(y_intern)
-
-    print(
-        "Are the prediction identical?:",
-        np.allclose(
-            y_intern.detach().cpu().numpy(),
-            y_dipu.detach().cpu().numpy(),
-            rtol,
-            atol,
-            True,
-        ),
-    )
+def test_rms_norm(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3):
+    x_base = torch.randn(5, 5, requires_grad=True).cuda()
+    x_base.retain_grad()
 
-    loss_fn = torch.nn.MSELoss()
+    x_intern = x_base.clone()
+    x_intern.retain_grad()
 
-    loss = loss_fn(y_label, y_intern)
-    input.retain_grad()
-    loss.backward()
-    # print("\nGradient for 'input':")
-    input_grad = input.grad
-    # print(input_grad)
+    hidden_szie = 5
 
-    loss2 = loss_fn(y_label, y_dipu)
-    input_dipu.retain_grad()
-    loss2.backward()
-    # print("\nGradient for 'input_dipu':")
-    input_dipu_grad = input_dipu.grad
-    # print(input_dipu_grad)
+    model_base = BaseRmsNorm(hidden_szie).cuda()
+    out_base = model_base(x_base)
+    out_base.backward(torch.ones_like(x_base))
+    grad_x_base = x_base.grad.cpu().numpy()
 
-    # Compare whether the two gradients are consistent
-    print(
-        "Are the gradients identical?:",
-        np.allclose(
-            input_grad.detach().cpu().numpy(),
-            input_dipu_grad.detach().cpu().numpy(),
-            rtol,
-            atol,
-            True,
-        ),
-    )
+    model_deeplink = DeeplinkRmsNorm(hidden_szie).cuda()
+    out_deeplink = model_deeplink(x_intern)
+    out_deeplink.backward(torch.ones_like(x_base))
+    grad_x_intern = x_intern.grad.cpu().numpy()
 
+    return np.allclose(grad_x_base, grad_x_intern, rtol, atol, True)
 
-print("\nTest case: normalized_shape == None:")
-test_forward_backward(InternLMRMSNorm, DeepLinkRMSNorm)
-print("\nTest case: normalized_shape == weight.size():")
-test_forward_backward(InternLMRMSNorm, DeepLinkRMSNorm_WithNormalizedShape)
+print("Test case: normalized_shape == None: grad_inputs closed ? ", test_rms_norm(InternLMRMSNorm, DeeplinkRMSNorm))
+print("Test case: normalized_shape == weight.size(): grad_inputs closed ? ", test_rms_norm(InternLMRMSNorm, DeeplinkRMSNorm_WithNormalizedShape))
