
Problem in Attention #27

Open
baifanxxx opened this issue Oct 16, 2024 · 1 comment

Comments

@baifanxxx

Hi,

Thanks for your contribution. I successfully ran PyramidKV with SDPA attention. However, I found that the naive (eager) attention, e.g. mistral_attn_forward_PyramidKV, has some problems. Here, attn_weights is required to have size (bsz, self.num_heads, q_len, kv_seq_len). That holds during prefill, but during decoding the compressed KV cache is used, so kv_seq_len no longer matches the number of KV entries in the cache, and the size of attn_weights disagrees with kv_seq_len. For the same reason, here attn_weights and attention_mask cannot be added because their sizes differ. Strangely, this problem only appears in the naive attention implementation; SDPA and flash attention have no such code.
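To make the mismatch concrete, the check I am referring to looks roughly like the following. This is a simplified sketch of the eager attention path with illustrative names, not code copied from the repo:

```python
import math

import torch


def eager_attn_sketch(query_states, key_states, attention_mask,
                      bsz, num_heads, q_len, kv_seq_len):
    # attn_weights has width key_states.shape[-2], i.e. the number of entries
    # actually kept in the (possibly compressed) KV cache.
    attn_weights = torch.matmul(
        query_states, key_states.transpose(2, 3)
    ) / math.sqrt(query_states.shape[-1])

    # During decode, kv_seq_len is derived from the full sequence length, so
    # this check fails once the cache has been compressed.
    if attn_weights.size() != (bsz, num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, num_heads, q_len, kv_seq_len)}, "
            f"but is {attn_weights.size()}"
        )

    # Same problem here: attention_mask is built with width kv_seq_len, while
    # attn_weights only covers the kept cache entries, so the addition fails too.
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    return attn_weights
```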

I would be grateful for any thoughts you have; please reply when you get a chance.

@Zefan-Cai
Owner

The naive attention may indeed have a bug. Our experiments were run with flash attention. I just read through the llama modeling file in transformers, and indeed only the naive attention adds the attention mask to the attention weights. Our codebase probably needs to reshape the attention mask to match attn_weights to avoid the error. As a short-term workaround, you can comment out that line: inference here does not use batched decoding, so the attention mask has no effect.
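A minimal sketch of that short-term workaround, assuming a hypothetical helper rather than an actual patch in the codebase, would be to add the mask only when its width already matches the compressed attn_weights, and otherwise skip it:

```python
from typing import Optional

import torch


def apply_mask_if_compatible(attn_weights: torch.Tensor,
                             attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
    """Add the causal mask only when its width matches attn_weights.

    With a compressed KV cache, a mask built for the full kv_seq_len no longer
    lines up with attn_weights. For single-sequence (non-batched) decoding the
    mask contributes nothing over the kept positions, so skipping it is safe.
    """
    if attention_mask is not None and attention_mask.shape[-1] == attn_weights.shape[-1]:
        attn_weights = attn_weights + attention_mask
    return attn_weights
```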
