diff --git a/deeplink_ext/llm_ops_for_ascend_speed.py b/deeplink_ext/llm_ops_for_ascend_speed.py
index 033bdb42..52dc5397 100644
--- a/deeplink_ext/llm_ops_for_ascend_speed.py
+++ b/deeplink_ext/llm_ops_for_ascend_speed.py
@@ -9,6 +9,8 @@
 assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward")
 assert hasattr(ext, "adamw")
 
+def is_nan(x):
+    return torch.isnan(x).any().item()
 
 class DeepLinkFlashSelfAttention(torch.autograd.Function):
     @staticmethod
@@ -128,7 +130,7 @@ def backward(ctx, grad_output):
         grad_input, grad_weight, grad_bias = ext.rms_norm_backward(
             hidden_states, grad_output, inv_rms, None, weight, bias, ctx.eps
         )
-        return grad_input, grad_weight, grad_bias, None
+        return grad_input, grad_weight, None, None
 
 
 def adamw_for_ascend_speed(
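
For context on the second hunk: `torch.autograd.Function.backward` must return exactly one value per `forward` input, with `None` in the positions that receive no gradient, which is the contract the changed return tuple follows. The sketch below illustrates that contract only; the `ScaleByWeight` Function, its arguments, and its shapes are hypothetical and are not the deeplink_ext RMSNorm implementation.

```python
import torch

# Minimal sketch (assumed example, not deeplink_ext code): backward() returns
# one entry per forward() input, in order, and uses None for inputs that take
# no gradient (here the unused bias and the float eps), mirroring the
# `return grad_input, grad_weight, None, None` change in the patch.
class ScaleByWeight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.save_for_backward(x, weight)
        return x * weight + (bias if bias is not None else 0.0)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output * weight            # gradient w.r.t. x
        grad_weight = (grad_output * x).sum(0)   # gradient w.r.t. weight
        # bias and eps slots get None: same tuple length as forward's inputs.
        return grad_x, grad_weight, None, None


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, requires_grad=True)
ScaleByWeight.apply(x, w, None, 1e-6).sum().backward()
print(x.grad.shape, w.grad.shape)
```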