Skip to content

Commit

Permalink
reset RMSNorm.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jingguo-st authored Jan 18, 2024
1 parent b1cced5 commit 2ad5f10
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions ext_apply/internlm/RMSNorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def forward(ctx, hidden_states, weight, bias, eps):
def backward(ctx, grad_output):
hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
eps = eps_tensor.item()
grad_input, *others = deeplink_ext.rms_norm_backward(
grad_input, grad_weight, grad_bias = deeplink_ext.rms_norm_backward(
hidden_states,
grad_output,
inv_rms,
Expand All @@ -53,10 +53,6 @@ def backward(ctx, grad_output):
bias,
eps
)
if isinstance(others, (list, tuple)) and len(others) == 2:
grad_weight, grad_bias = others
else:
grad_weight, grad_bias = others[0], None
return grad_input, grad_weight, grad_bias, None

class _DeepLinkRMSNormFunction_WithNormalizedShape(torch.autograd.Function):
Expand Down Expand Up @@ -88,7 +84,7 @@ def backward(ctx, grad_output):
weight = weight.float()
bias = bias.float()
grad_output = grad_output.float()
grad_input, *others = deeplink_ext.rms_norm_backward(
grad_input, grad_weight, grad_bias = deeplink_ext.rms_norm_backward(
hidden_states,
grad_output,
inv_rms,
Expand All @@ -97,11 +93,6 @@ def backward(ctx, grad_output):
bias,
eps
)
if isinstance(others, (tuple, list)) and len(others) == 2:
grad_weight, grad_bias = others
else:
grad_weight = others[0], None

grad_output = grad_output.half()
hidden_states = hidden_states.half()
inv_rms = inv_rms.half()
Expand Down

0 comments on commit 2ad5f10

Please sign in to comment.