From c8f2125a575b0e75b8e490f55292c8263746248e Mon Sep 17 00:00:00 2001
From: jingguo-st
Date: Thu, 18 Jan 2024 11:44:13 +0800
Subject: [PATCH 1/4] 1. fix extRmsNorm: inv_rms shape; 2. support Ascend RMSNorm and RMSNormGrad without bias: DeeplinkRMSNorm, DeeplinkRMSNorm_WithNormalizedShape; 3. rewrite test case for rms_norm

---
 ext_apply/internlm/RMSNorm.py | 14 ++++++-
 ext_op/example_ext.cpp        |  5 ++-
 test/test_rms_internlm.py     | 77 ++++++++++-------------------------
 3 files changed, 38 insertions(+), 58 deletions(-)

diff --git a/ext_apply/internlm/RMSNorm.py b/ext_apply/internlm/RMSNorm.py
index b6d13def..cb0d8a10 100644
--- a/ext_apply/internlm/RMSNorm.py
+++ b/ext_apply/internlm/RMSNorm.py
@@ -44,7 +44,7 @@ def forward(ctx, hidden_states, weight, bias, eps):
     def backward(ctx, grad_output):
         hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
         eps = eps_tensor.item()
-        grad_input, grad_weight, grad_bias = deeplink_ext.rms_norm_backward(
+        grad_input, *others = deeplink_ext.rms_norm_backward(
             hidden_states,
             grad_output,
             inv_rms,
@@ -53,6 +53,10 @@ def backward(ctx, grad_output):
             bias,
             eps
         )
+        if isinstance(others, (list, tuple)) and len(others) == 2:
+            grad_weight, grad_bias = others
+        else:
+            grad_weight, grad_bias = others[0], None
         return grad_input, grad_weight, grad_bias, None

 class _DeeplinkRMSNormFunction_WithNormalizedShape(torch.autograd.Function):
@@ -84,7 +88,7 @@ def backward(ctx, grad_output):
         weight = weight.float()
         bias = bias.float()
         grad_output = grad_output.float()
-        grad_input, grad_weight, grad_bias = deeplink_ext.rms_norm_backward(
+        grad_input, *others = deeplink_ext.rms_norm_backward(
             hidden_states,
             grad_output,
             inv_rms,
@@ -93,6 +97,12 @@ def backward(ctx, grad_output):
             bias,
             eps
         )
+        if isinstance(others, (tuple, list)) and len(others) == 2:
+            grad_weight, grad_bias = others
+        else:
+            grad_weight = others[0]
+            grad_bias = None
+
         grad_output = grad_output.half()
         hidden_states = hidden_states.half()
         inv_rms = inv_rms.half()
diff --git a/ext_op/example_ext.cpp b/ext_op/example_ext.cpp
index 44d3056a..7bd01973 100644
--- a/ext_op/example_ext.cpp
+++ b/ext_op/example_ext.cpp
@@ -43,7 +43,10 @@ auto extRmsNorm(const at::Tensor& input,
                 const at::Tensor& weight, const at::Tensor& bias, double eps) {
     at::OptionalIntArrayRef normalized_shape_at =
         optionalIntArrayToIntArrayRefOrDefault(normalized_shape, weight.sizes());
-    auto inv_rms = at::empty_like(input);
+    auto input_shape = input.sizes();
+    std::vector<int64_t> input_size(input_shape.size(), 1);
+    std::copy(input_shape.begin(), input_shape.end() - 1, input_size.begin());
+    auto inv_rms = at::empty(input_size, input.options());
     auto output = at::empty_like(input);
     callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight,
               bias, eps);
diff --git a/test/test_rms_internlm.py b/test/test_rms_internlm.py
index 452d9e64..d476babb 100644
--- a/test/test_rms_internlm.py
+++ b/test/test_rms_internlm.py
@@ -1,67 +1,34 @@
-from DeepLinkExt.ext_apply.internlm.RMSNorm import (
-    InternLMRMSNorm,
-    DeeplinkRMSNorm,
-    DeeplinkRMSNorm_WithNormalizedShape,
-)
 import torch
-from torch import nn
 import torch_dipu
 import numpy as np
+from ext_apply.internlm.RMSNorm import (
+    InternLMRMSNorm,
+    DeeplinkRMSNorm,
+    DeeplinkRMSNorm_WithNormalizedShape,
+)

-def test_forward_backward(Basenet, Testnet, rtol=1e-5, atol=1e-5):
-    input = torch.randn(5, 5, requires_grad=True).cuda()
-    input_dipu = input.clone()
-    hidden_size = 5
-
-    intern = Basenet(hidden_size).cuda()
-    deep = Testnet(hidden_size).cuda()
-    y_intern = intern(input)
-    y_dipu = deep(input_dipu)
-
-    y_label = torch.ones_like(y_intern)
-    print(
-        "Are the prediction identical?:",
-        np.allclose(
-            y_intern.detach().cpu().numpy(),
-            y_dipu.detach().cpu().numpy(),
-            rtol,
-            atol,
-            True,
-        ),
-    )
+def test_rms_norm(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3):
+    x_base = torch.randn(5, 5, requires_grad=True).cuda()
+    x_base.retain_grad()

-    loss_fn = torch.nn.MSELoss()
+    x_intern = x_base.clone()
+    x_intern.retain_grad()

-    loss = loss_fn(y_label, y_intern)
-    input.retain_grad()
-    loss.backward()
-    # print("\nGradient for 'input':")
-    input_grad = input.grad
-    # print(input_grad)
+    hidden_size = 5

-    loss2 = loss_fn(y_label, y_dipu)
-    input_dipu.retain_grad()
-    loss2.backward()
-    # print("\nGradient for 'input_dipu':")
-    input_dipu_grad = input_dipu.grad
-    # print(input_dipu_grad)
+    model_base = BaseRmsNorm(hidden_size).cuda()
+    out_base = model_base(x_base)
+    out_base.backward(torch.ones_like(x_base))
+    grad_x_base = x_base.grad.cpu().numpy()

-    # Compare whether the two results are consistent
-    print(
-        "Are the gradients identical?:",
-        np.allclose(
-            input_grad.detach().cpu().numpy(),
-            input_dipu_grad.detach().cpu().numpy(),
-            rtol,
-            atol,
-            True,
-        ),
-    )
+    model_deeplink = DeeplinkRmsNorm(hidden_size).cuda()
+    out_deeplink = model_deeplink(x_intern)
+    out_deeplink.backward(torch.ones_like(x_base))
+    grad_x_intern = x_intern.grad.cpu().numpy()

+    return np.allclose(grad_x_base, grad_x_intern, rtol, atol, True)

-print("\nTest case: normalized_shape == None:")
-test_forward_backward(InternLMRMSNorm, DeeplinkRMSNorm)
-print("\nTest case: normalized_shape == weight.size():")
-test_forward_backward(InternLMRMSNorm, DeeplinkRMSNorm_WithNormalizedShape)
+print("Test case: normalized_shape == None: grad_inputs close? ", test_rms_norm(InternLMRMSNorm, DeeplinkRMSNorm))
+print("Test case: normalized_shape == weight.size(): grad_inputs close? ", test_rms_norm(InternLMRMSNorm, DeeplinkRMSNorm_WithNormalizedShape))
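Background for the inv_rms shape change in extRmsNorm above: RMSNorm normalizes over the last (hidden) dimension, so the reciprocal-RMS statistic holds one value per token rather than one per element. A minimal PyTorch sketch of the reference math, illustrative only and not the DIOPI kernel; the helper name rms_norm_ref is made up here:

import torch

def rms_norm_ref(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # One mean-square statistic per token, computed over the hidden dimension.
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(variance + eps)  # shape: input.shape[:-1] + [1]
    return hidden_states * inv_rms * weight, inv_rms

x = torch.randn(4, 5)
w = torch.ones(5)
out, inv_rms = rms_norm_ref(x, w)
assert out.shape == (4, 5)
assert inv_rms.shape == (4, 1)  # why inv_rms is no longer allocated with empty_like(input)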
From a39abfd92a19317e3b9560e6cab7ae52c9c1fc01 Mon Sep 17 00:00:00 2001
From: jingguo-st
Date: Thu, 18 Jan 2024 11:58:54 +0800
Subject: [PATCH 2/4] support ascend rms_norm

---
 ext_apply/internlm/RMSNorm.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/ext_apply/internlm/RMSNorm.py b/ext_apply/internlm/RMSNorm.py
index 2003a4e1..95df6aca 100644
--- a/ext_apply/internlm/RMSNorm.py
+++ b/ext_apply/internlm/RMSNorm.py
@@ -100,8 +100,7 @@ def backward(ctx, grad_output):
         if isinstance(others, (tuple, list)) and len(others) == 2:
             grad_weight, grad_bias = others
         else:
-            grad_weight = others[0]
-            grad_bias = None
+            grad_weight, grad_bias = others[0], None

         grad_output = grad_output.half()
         hidden_states = hidden_states.half()
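The else branch adjusted above covers kernels, such as the Ascend RMSNormGrad without bias, that return only (grad_input, grad_weight). A standalone sketch of that unpacking pattern; it assumes nothing about the real deeplink_ext signature, and the tuples below are stand-ins for illustration:

def unpack_rms_norm_backward(results):
    # Normalize a 2- or 3-element result to (grad_input, grad_weight, grad_bias).
    grad_input, *others = results
    if isinstance(others, (tuple, list)) and len(others) == 2:
        grad_weight, grad_bias = others
    else:
        # No-bias kernels return only (grad_input, grad_weight).
        grad_weight, grad_bias = others[0], None
    return grad_input, grad_weight, grad_bias

assert unpack_rms_norm_backward(("gi", "gw", "gb")) == ("gi", "gw", "gb")
assert unpack_rms_norm_backward(("gi", "gw")) == ("gi", "gw", None)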
", test_rms_norm(InternLMRMSNorm, DeeplinkRMSNorm_WithNormalizedShape)) From a39abfd92a19317e3b9560e6cab7ae52c9c1fc01 Mon Sep 17 00:00:00 2001 From: jingguo-st Date: Thu, 18 Jan 2024 11:58:54 +0800 Subject: [PATCH 2/4] support ascend rms_norm --- ext_apply/internlm/RMSNorm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ext_apply/internlm/RMSNorm.py b/ext_apply/internlm/RMSNorm.py index 2003a4e1..95df6aca 100644 --- a/ext_apply/internlm/RMSNorm.py +++ b/ext_apply/internlm/RMSNorm.py @@ -100,8 +100,7 @@ def backward(ctx, grad_output): if isinstance(others, (tuple, list)) and len(others) == 2: grad_weight, grad_bias = others else: - grad_weight = others[0] - grad_bias = None + grad_weight = others[0], None grad_output = grad_output.half() hidden_states = hidden_states.half() From b1cced51750283497b29359de57f5b8b85c2290b Mon Sep 17 00:00:00 2001 From: Fu Jingguo Date: Thu, 18 Jan 2024 13:06:10 +0800 Subject: [PATCH 3/4] support ascend rms_norm --- ext_op/example_ext.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext_op/example_ext.cpp b/ext_op/example_ext.cpp index 7ba11d3b..34713501 100644 --- a/ext_op/example_ext.cpp +++ b/ext_op/example_ext.cpp @@ -44,8 +44,8 @@ auto extRmsNorm(const at::Tensor& input, at::OptionalIntArrayRef normalized_shape_at = optionalIntArrayToIntArrayRefOrDefault(normalized_shape, weight.sizes()); auto input_shape = input.sizes(); - std::vector input_size(input_shape.size(), 1); - std::copy(input_shape.begin(), input_shape.end() - 1, input_size.begin()); + std::vector input_size(input_shape.begin(), input_shape.end()); + input_size.back() = 1; auto inv_rms = at::empty(input_size, input.options()); auto output = at::empty_like(input); callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight, From 2ad5f105b459c059665f8d4bd03a846d956c6fee Mon Sep 17 00:00:00 2001 From: Fu Jingguo Date: Thu, 18 Jan 2024 13:13:14 +0800 Subject: [PATCH 4/4] reset RMSNorm.py --- ext_apply/internlm/RMSNorm.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/ext_apply/internlm/RMSNorm.py b/ext_apply/internlm/RMSNorm.py index 95df6aca..d2e3453c 100644 --- a/ext_apply/internlm/RMSNorm.py +++ b/ext_apply/internlm/RMSNorm.py @@ -44,7 +44,7 @@ def forward(ctx, hidden_states, weight, bias, eps): def backward(ctx, grad_output): hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors eps = eps_tensor.item() - grad_input, *others = deeplink_ext.rms_norm_backward( + grad_input, grad_weight, grad_bias = deeplink_ext.rms_norm_backward( hidden_states, grad_output, inv_rms, @@ -53,10 +53,6 @@ def backward(ctx, grad_output): bias, eps ) - if isinstance(others, (list, tuple)) and len(others) == 2: - grad_weight, grad_bias = others - else: - grad_weight, grad_bias = others[0], None return grad_input, grad_weight, grad_bias, None class _DeepLinkRMSNormFunction_WithNormalizedShape(torch.autograd.Function): @@ -88,7 +84,7 @@ def backward(ctx, grad_output): weight = weight.float() bias = bias.float() grad_output = grad_output.float() - grad_input, *others = deeplink_ext.rms_norm_backward( + grad_input, grad_weight, grad_bias = deeplink_ext.rms_norm_backward( hidden_states, grad_output, inv_rms, @@ -97,11 +93,6 @@ def backward(ctx, grad_output): bias, eps ) - if isinstance(others, (tuple, list)) and len(others) == 2: - grad_weight, grad_bias = others - else: - grad_weight = others[0], None - grad_output = grad_output.half() hidden_states = hidden_states.half() inv_rms = inv_rms.half()