From 7997cf494e8e9c096154c2fa285a373cc215e0ab Mon Sep 17 00:00:00 2001
From: POI-WX
Date: Wed, 20 Mar 2024 22:55:52 +0800
Subject: [PATCH] fix bug of return grad bias

---
 deeplink_ext/llm_ops_for_ascend_speed.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/deeplink_ext/llm_ops_for_ascend_speed.py b/deeplink_ext/llm_ops_for_ascend_speed.py
index 033bdb42..52dc5397 100644
--- a/deeplink_ext/llm_ops_for_ascend_speed.py
+++ b/deeplink_ext/llm_ops_for_ascend_speed.py
@@ -9,6 +9,8 @@
 assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward")
 assert hasattr(ext, "adamw")
 
+def is_nan(x):
+    return torch.isnan(x).any().item()
 
 class DeepLinkFlashSelfAttention(torch.autograd.Function):
     @staticmethod
@@ -128,7 +130,7 @@ def backward(ctx, grad_output):
         grad_input, grad_weight, grad_bias = ext.rms_norm_backward(
             hidden_states, grad_output, inv_rms, None, weight, bias, ctx.eps
         )
-        return grad_input, grad_weight, grad_bias, None
+        return grad_input, grad_weight, None, None
 
 
 def adamw_for_ascend_speed(
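
Note on the pattern the patch switches to: torch.autograd.Function.backward must return exactly one value per argument of forward, with None for any argument that should not receive a gradient. The sketch below is a minimal, hypothetical illustration of that contract only; ToyScaleFunction and its toy math are not part of this repository and do not reproduce the real RMSNorm kernel.

    import torch


    class ToyScaleFunction(torch.autograd.Function):
        # Hypothetical op: y = x * weight + bias, with x of shape (batch, features)
        # and weight/bias of shape (features,). bias is treated as non-learnable,
        # so its gradient slot is returned as None, mirroring the patched return.

        @staticmethod
        def forward(ctx, x, weight, bias, eps):
            ctx.save_for_backward(x, weight)
            ctx.eps = eps  # kept only to mirror a (x, weight, bias, eps) signature
            return x * weight + bias

        @staticmethod
        def backward(ctx, grad_output):
            x, weight = ctx.saved_tensors
            grad_input = grad_output * weight            # d(out)/d(x)
            grad_weight = (grad_output * x).sum(dim=0)   # d(out)/d(weight)
            # One return value per forward argument (x, weight, bias, eps);
            # None marks arguments that receive no gradient.
            return grad_input, grad_weight, None, None


    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(8, requires_grad=True)
    b = torch.randn(8)  # no requires_grad: bias gets no gradient
    out = ToyScaleFunction.apply(x, w, b, 1e-6)
    out.sum().backward()  # fills x.grad and w.grad; b.grad remains None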