From 7fc26c7e55afffe0cabe97bcbadaa0386ead0186 Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Wed, 27 Mar 2024 11:07:28 +0800 Subject: [PATCH] fix rms norm --- deeplink_ext/internlm_ops/rms_norm/deeplink.py | 9 ++++----- tests/test_rms_lightlm.py | 9 ++++++++- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py index 9403340d..d42208bc 100644 --- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py +++ b/deeplink_ext/internlm_ops/rms_norm/deeplink.py @@ -11,11 +11,11 @@ class _DeepLinkRMSNormFunction(torch.autograd.Function): @staticmethod def forward(ctx, hidden_states, weight, bias, eps): output = torch.empty_like(hidden_states) - inv_rms_shape = list(hidden_states.shape[:-1], 1) + inv_rms_shape = list(hidden_states.shape[:-1]) + [1] inv_rms = torch.empty( inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device ) - ext.rms_norm(output, inv_rms, hidden_states, None, weight, bias, eps) + ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, bias, eps) ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps)) return output @@ -28,7 +28,6 @@ def backward(ctx, grad_output): grad_input = torch.empty_like(hidden_states) grad_weight = torch.empty_like(weight) grad_bias = torch.empty_like(bias) - ext.rms_norm_backward( grad_input, grad_weight, @@ -38,7 +37,7 @@ def backward(ctx, grad_output): weight, bias, inv_rms, - None, + weight.shape, eps, ) return grad_input, grad_weight, grad_bias, None @@ -48,7 +47,7 @@ class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function): @staticmethod def forward(ctx, hidden_states, weight, bias, eps, normalized_shape): output = torch.empty_like(hidden_states, dtype=torch.float32) - inv_rms_shape = list(hidden_states.shape[:-1], 1) + inv_rms_shape = list(hidden_states.shape[:-1]) + [1] inv_rms = torch.empty( inv_rms_shape, dtype=torch.float32, device=hidden_states.device ) diff --git a/tests/test_rms_lightlm.py b/tests/test_rms_lightlm.py index a5ee2baa..0e6b911c 100644 --- a/tests/test_rms_lightlm.py +++ b/tests/test_rms_lightlm.py @@ -22,7 +22,6 @@ inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=input.device) ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, 1e-6) - # 使用 RMS normalization 反向传播 grad_input = torch.empty_like(grad_output) grad_weight = torch.empty_like(weight) @@ -44,5 +43,13 @@ print("Grad Input:", grad_input) print("Grad Weight:", grad_weight) print("Grad Bias:", grad_bias) + +input.requires_grad_(True) +weight.requires_grad_(True) +bias.requires_grad_(True) b = input * torch.rsqrt(input.pow(2).mean(-1, keepdim=True) + 1e-6) * weight +grads = torch.autograd.grad(b, [input, weight, bias], grad_output, allow_unused=True) assert torch.allclose(output, b) +assert torch.allclose(grad_input, grads[0]) +assert torch.allclose(grad_weight, grads[1]) +# assert torch.allclose(grad_bias, grads[2])