From f98cb27493ab58001e62f56dffac2ba2b1546475 Mon Sep 17 00:00:00 2001 From: POI-WX Date: Wed, 20 Mar 2024 17:42:46 +0800 Subject: [PATCH] fix bug --- deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py b/deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py index db43ab6e..277f834f 100644 --- a/deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py +++ b/deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py @@ -11,7 +11,7 @@ class DeepLinkFlashAttentionQKVPackedFunc(torch.autograd.Function): def forward(ctx, qkv, dropout_p, softmax_scale, causal): if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) - head_num = qkv.shape[2] + head_num = qkv.shape[3] ( out, attention_mask,