diff --git a/deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py b/deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py
index db43ab6e..277f834f 100644
--- a/deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py
+++ b/deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py
@@ -11,7 +11,7 @@ class DeepLinkFlashAttentionQKVPackedFunc(torch.autograd.Function):
     def forward(ctx, qkv, dropout_p, softmax_scale, causal):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        head_num = qkv.shape[2]
+        head_num = qkv.shape[3]
         (
             out,
             attention_mask,