
Commit

fix bug of return grad bias
POI-WX committed Mar 20, 2024
1 parent 5891226 commit 7997cf4
Showing 1 changed file with 3 additions and 1 deletion.
deeplink_ext/llm_ops_for_ascend_speed.py (4 changes: 3 additions & 1 deletion)
@@ -9,6 +9,8 @@
 assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward")
 assert hasattr(ext, "adamw")
 
+def is_nan(x):
+    return torch.isnan(x).any().item()
 
 class DeepLinkFlashSelfAttention(torch.autograd.Function):
     @staticmethod
@@ -128,7 +130,7 @@ def backward(ctx, grad_output):
         grad_input, grad_weight, grad_bias = ext.rms_norm_backward(
             hidden_states, grad_output, inv_rms, None, weight, bias, ctx.eps
         )
-        return grad_input, grad_weight, grad_bias, None
+        return grad_input, grad_weight, None, None
 
 
 def adamw_for_ascend_speed(
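Note on the fix: `torch.autograd.Function.backward` must return exactly one value per argument of `forward`, and `None` for arguments that receive no gradient. The sketch below is a minimal, hypothetical illustration of that convention (the class and tensor shapes are made up and are not code from deeplink_ext); it mirrors the corrected return line, where the `bias` and `eps` slots are `None`.

```python
import torch


class ScaleNoBias(torch.autograd.Function):
    """Hypothetical example of the backward return convention:
    one return slot per forward input, None for inputs without gradients."""

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        # bias is accepted for interface compatibility but not used here,
        # and eps is a plain Python float.
        ctx.save_for_backward(x, weight)
        return x * weight

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output * weight
        grad_weight = (grad_output * x).sum(dim=0)
        # Four forward inputs (x, weight, bias, eps) -> four return slots.
        # bias and eps get no gradient, so their slots are None.
        return grad_x, grad_weight, None, None
```

Calling `ScaleNoBias.apply(x, weight, None, 1e-6)` then backpropagates into `x` and `weight` only, leaving the unused `bias` slot out of the graph.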
