Initial fix for token corruption when batching (#665)
There are two problems fixed by two code changes in this PR.

# Cache over-allocation

This is a small problem that causes us to over-allocate cache pages in the KV cache. Further work is needed so that service.py and {Base,Trie}PagedAttentionCache allocate a precise and consistent amount of cache, but this change is sufficient to solve the problem at hand.

# Zero-padding of seq_len and start_position

For unused requests in a batch, seq_len and start_position are usually filled with 0. This injects NaNs that are written to page 0.

Page index 0 serves a special padding role in our batching system: it is used to fill unused pages for shorter requests and to pad unused requests within a batch. Under normal circumstances, NaNs in page 0 wouldn't be problematic, since our masking system is designed to ignore values beyond the current token. For example, when generating token 17 with a page list of [255, 254, 0], we should never need to read from the padding page.

The issue stems from our current masking implementation. Instead of directly discarding masked values, we mask by adding negative infinity to them before applying an exponential function. While this typically works fine and results in zeroes, it breaks down when the masked value is NaN: NaN + (-inf) is still NaN, and exp(NaN) is NaN. When this happens, NaN values from page 0 leak into our calculations, resulting in token corruption.
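A minimal NumPy sketch of the failure mode (this is illustrative only, not the actual attention kernel; the `scores` values and `keep` mask are made up, with the last position standing in for a read from padding page 0):

```python
import numpy as np

# Attention scores for one query over 4 key positions. The last position
# stands in for a read from padding page 0, whose contents are NaN.
scores = np.array([0.5, 1.2, -0.3, np.nan])
keep = np.array([True, True, True, False])  # mask out the padding position

# Additive masking, as described above: add -inf, then exponentiate.
# NaN + (-inf) is still NaN, and exp(NaN) is NaN, so the NaN survives
# the mask and poisons the normalization sum.
additive = np.exp(scores + np.where(keep, 0.0, -np.inf))
print(additive / additive.sum())  # -> [nan nan nan nan]

# Replacing masked values outright discards the NaN before exp,
# so the padding position contributes exactly zero weight.
replaced = np.exp(np.where(keep, scores, -np.inf))
print(replaced / replaced.sum())  # -> finite weights, 0.0 at the padding slot
```

This is why NaNs written to page 0 can corrupt tokens even though the padding position is nominally masked out.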