Commit 7fc26c7: fix rms norm
zhangzefeng92 committed Mar 27, 2024
1 parent 5027589
Showing 2 changed files with 12 additions and 6 deletions.
deeplink_ext/internlm_ops/rms_norm/deeplink.py (4 additions, 5 deletions)
@@ -11,11 +11,11 @@ class _DeepLinkRMSNormFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, hidden_states, weight, bias, eps):
         output = torch.empty_like(hidden_states)
-        inv_rms_shape = list(hidden_states.shape[:-1], 1)
+        inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
         inv_rms = torch.empty(
             inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device
         )
-        ext.rms_norm(output, inv_rms, hidden_states, None, weight, bias, eps)
+        ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, bias, eps)

         ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
         return output
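The first change fixes a latent crash: list() accepts at most one argument, so
list(hidden_states.shape[:-1], 1) raises TypeError at runtime. The fixed line
appends a trailing 1 so inv_rms stays broadcastable against hidden_states. The
second change passes weight.shape to the kernel as the normalized shape instead
of None. A minimal sketch of the shape fix, using a hypothetical 3-D activation:

    import torch

    # hypothetical activation of shape (batch, seq_len, hidden_size)
    hidden_states = torch.randn(2, 4, 8)

    # old: list(hidden_states.shape[:-1], 1) -> TypeError (list takes one argument)
    # new: keep every dim but the last, then append 1 so inv_rms broadcasts
    inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
    print(inv_rms_shape)  # [2, 4, 1]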
@@ -28,7 +28,6 @@ def backward(ctx, grad_output):
         grad_input = torch.empty_like(hidden_states)
         grad_weight = torch.empty_like(weight)
         grad_bias = torch.empty_like(bias)
-
         ext.rms_norm_backward(
             grad_input,
             grad_weight,
@@ -38,7 +37,7 @@
             weight,
             bias,
             inv_rms,
-            None,
+            weight.shape,
             eps,
         )
         return grad_input, grad_weight, grad_bias, None
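The backward call gets the same treatment: weight.shape replaces None as the
normalized-shape argument. For context, the forward computation the kernel is
expected to match is this pure-PyTorch reference (a sketch mirroring the check
added in the test file below, where the bias term goes unused):

    import torch

    def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Normalize by the root-mean-square over the last dimension, then scale.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return x * inv_rms * weight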
@@ -48,7 +47,7 @@ class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function):
     @staticmethod
     def forward(ctx, hidden_states, weight, bias, eps, normalized_shape):
         output = torch.empty_like(hidden_states, dtype=torch.float32)
-        inv_rms_shape = list(hidden_states.shape[:-1], 1)
+        inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
         inv_rms = torch.empty(
             inv_rms_shape, dtype=torch.float32, device=hidden_states.device
         )
tests/test_rms_lightlm.py (8 additions, 1 deletion)
@@ -22,7 +22,6 @@
 inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=input.device)
 ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, 1e-6)

-
 # RMS normalization backward pass
 grad_input = torch.empty_like(grad_output)
 grad_weight = torch.empty_like(weight)
@@ -44,5 +43,13 @@
 print("Grad Input:", grad_input)
 print("Grad Weight:", grad_weight)
 print("Grad Bias:", grad_bias)
+
+input.requires_grad_(True)
+weight.requires_grad_(True)
+bias.requires_grad_(True)
+b = input * torch.rsqrt(input.pow(2).mean(-1, keepdim=True) + 1e-6) * weight
+grads = torch.autograd.grad(b, [input, weight, bias], grad_output, allow_unused=True)
+assert torch.allclose(output, b)
+assert torch.allclose(grad_input, grads[0])
+assert torch.allclose(grad_weight, grads[1])
+# assert torch.allclose(grad_bias, grads[2])
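The new block validates the extension against plain autograd: b recomputes RMS
norm in pure PyTorch, and torch.autograd.grad produces reference gradients.
Note that bias never enters b, so with allow_unused=True its gradient comes
back as None, which is presumably why the grad_bias assertion is commented out.
A self-contained version of the same check, with hypothetical shapes standing
in for the setup that sits above the visible hunk:

    import torch

    # hypothetical setup; the real test defines these above the visible hunk
    hidden = 8
    input = torch.randn(2, 4, hidden, requires_grad=True)
    weight = torch.randn(hidden, requires_grad=True)
    grad_output = torch.randn(2, 4, hidden)

    # pure-PyTorch RMS norm (no bias term, matching the test's reference `b`)
    b = input * torch.rsqrt(input.pow(2).mean(-1, keepdim=True) + 1e-6) * weight
    grad_input_ref, grad_weight_ref = torch.autograd.grad(
        b, [input, weight], grad_output
    )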
