Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

对比了下StreamingLLM源码实现,貌似PyramidKV中的实现位置编码这块有问题 #19

Open
free-stardust opened this issue Jul 22, 2024 · 4 comments

Comments

@free-stardust
Copy link

free-stardust commented Jul 22, 2024

你好,PyramidKV 这个实现确实很巧妙,但是有个问题:我看了下 StreamingLLM 的源码,这个实现中,由于位置编码重新排列,所以注意力实现那里在存 kv cache 之前,并没有对 k 进行位置嵌入,在存 kv cache 后才对 k 进行了临时的位置嵌入和注意力计算;但是 PyramidKV 对应的 StreamingLLM 实现,是在对 q 和 k 同时进行位置嵌入之后进行了kv cache 保存和注意力计算,结果就导致没有实现 StreamingLLM 中的 kv cache 对应位置的重新排列,省略了这块后作为对比测试,是不是有点问题?

@free-stardust free-stardust reopened this Jul 22, 2024
@free-stardust free-stardust changed the title 对比了下StreamingLLM源码实现,看着PyramidKV中的实现位置编码这块有问题,是我理解的有问题吗? 对比了下StreamingLLM源码实现,貌似PyramidKV中的实现位置编码这块有问题。。。 Jul 22, 2024
@free-stardust free-stardust changed the title 对比了下StreamingLLM源码实现,貌似PyramidKV中的实现位置编码这块有问题。。。 对比了下StreamingLLM源码实现,貌似PyramidKV中的实现位置编码这块有问题 Jul 22, 2024
@Zefan-Cai
Copy link
Owner

我推测原来的StreamingLLM为了多次编码4k长度的文本才进行如你所说的操作:在存 kv cache 后才对 k 进行了临时的位置嵌入和注意力计算。因为多次编码的话,在编码中如果加入位置编码,会导致最终前面4个token被多次添加位置编码。如果是进行longBench的测试的话,不会进行多次编码,所以和SnapKV、PyramidKV保持一样的操作应该页合理,也就是:对 q 和 k 同时进行位置嵌入之后进行kv cache 保存和注意力计算。

@free-stardust
Copy link
Author

我描述的好像有点问题,我的理解是这样的:

因为 StreamingLLM 原始版本实现里面有个 pos shift,用论文中的例子说就是一个长为 8 的序列移除其中的 4 和 5 之后,各个 token 的位置 pos_id 就变成了这样 [0, 1, 2, 3, 6, 7, 8],这种情况下计算 k 的位置编码时,就按照当前 kv cache 的当前内容长度,将后面断续的位置 pos_id 重新进行了连续性替换,变成了 [0, 1, 2, 3, 4, 5, 6],然后最新生层的 k,编进去则为 7,而不是默认实现对应的位置 id=9,之后基于 [0, 1, 2, 3, 4, 5, 6, 7] 这个序列对 kv cache 进行新的的位置编码。

如果这个序列不重新配置,那么就是 transformers 库默认的实现,即默认的 kv cache 是利用 [0, 1, 2, 3, 6, 7, 8] 这个 pos_id 值进行的位置编码,然后新生成时,对于 pos_id=9 的 token,则是按照 9 进行编码。

因此默认情况下位置编码 k 时,即使 kv cache 中的 token 会被逐渐被 evict,但他们位置编码 pos_id 实际就是各种变动的不连续序列,而 StreamingLLM 对应的 pos shift 实现里面为了达到位置动态配置,保存的 k 是无位置编码的,只在注意力计算时才会按照当前的 kv cache 长度对整个 kv cache 中的 k 进行位置编码嵌入,这个如果在存 kv cache 前就进行位置编码的话,确实会重复添加位置编码,但这种种情况下,不止是前面 4 个会重复添加位置编码,后续的所有 token 都会被重复添加位置编码,但这实际也可以通过切片进行规避,而且默认的注意力实现也只是计算最新的 token 的 k 并对其进行编码存入 cache,所以我理解的是,他这么做的目的,就是为了对 evict 之后的 kv cache 的各个 token 相对位置重新配置,所以进行了这个操作。

但如果不进行这个重新配置的话,就是采用默认的注意力实现,位置 pos_id 不会重新配置,就是按照正常的 pos_id 进行对 k 进行位置编码,其实确实也可以,毕竟 StreamingLLM 把这个 pos shift 作为了可配置项。

然后纠结这个是因为我自己测试的时候,发现他这个 trick 效果确实会好一点,而且他论文里面也对这个进行了详细叙述,但是看 PyramidKV 的实现里面没有这个部分,有点疑惑。

@Zefan-Cai
Copy link
Owner

我确实是没注意到这个trick。这个问题应该分为两个方面:

1.保存的kv cache不添加位置编码,应该确实是因为会引入重复添加位置编码的问题。如果每次保存都添加位置编码,前四个token和local window都会被多次添加位置编码;而且也不方便在保存后在下一次encode时输入pos ⇧以后得到的新的位置编码。这个操作和其他的kv cache compression工作是不同的。其他的单次编码的工作可以在保存以前就给kv cache添加位置编码,不会影响结果。

  1. 位置编码偏移可能确实是一个trick,我之前没有注意。其他的工作一般没有使用改动的位置编码。直觉上讲,不改动位置编码效果应该会好一些,因为这样模型能对每个token的位置有正确的感知。StreamingLLM使用这个trick的原因,我推测是它需要用4k context length的LLM编码4k长度的内容多次,如果使用没有pos shift的话可能会使用比4k大得多的位置编码。这样的话,LLM会接触到一个没有在训练时见过的位置编码,因此对效果有负面影响。

@free-stardust
Copy link
Author

确实,如果按照 StreamingLLM 这种实现,确实会存在重复位置编码的情况,不过从他这个代码逻辑上看,他这种实现应该就是为了服务他的 trick,相比之下其他的 kv cache 压缩是在默认的增量更新基础上进行的修改,所以不会影响结果。

然后关于第二方面,这个我也感觉很奇怪,按 RoPE 的原理来说,应该是保存相对位置更好,并且这个理论上是具有更好的外推性,但是在 StreamingLLM 重新编码位置后,相对位置信息改变了,生成效果居然好了,这确实有点神奇,我第一反应是位置编码算法特性的问题,您说的这个推测,我还真没想到,学习了,感谢感谢!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants