Initial fix for token corruption when batching (#665)
There are 2 problems fixed by 2 code changes in this PR.

# Cache over-allocation.

This is a small problem that causes us to over-allocate cache pages in
the KV cache. Further work will be needed to get service.py and
{Base,Trie}PagedAttentionCache to allocate a precise & consistent amount
of cache, but the change here is sufficient to solve the problem at hand.
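To make the arithmetic concrete, here is a minimal sketch of the block-count calculation as it stands after this change; the names mirror the diff below, but the surrounding batching context is simplified:

```python
import math

def block_count_for_batch(exec_requests, seq_stride, is_decode):
    # Reserve one extra token slot during decode for the token being generated.
    extra_token_slots = 1 if is_decode else 0
    # Batch sequence length: longest request in the batch, plus the extra slot.
    bsl = max(extra_token_slots + len(r.input_token_ids) for r in exec_requests)
    # Round up to a multiple of the sequence stride (tokens per cache page).
    bsl = int(math.ceil(bsl / seq_stride) * seq_stride)
    # Number of KV cache pages to allocate for each request in the batch.
    return bsl // seq_stride
```

Before this change, bsl was computed from `r.start_position + len(r.input_token_ids)`, which over-allocated pages as described above.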

# Zero-padding of seq_len and start_position

For unused requests in a batch, seq_len and start_position are usually
filled with 0. Those zeroes lead to a division by zero, which produces
NaNs that are then written to page 0.

Page index 0 serves a special padding role in our batching system. It's
used to fill unused pages for shorter requests and to pad unused
requests within a batch.

Under normal circumstances, NaNs in page 0 wouldn't be problematic since
our masking system is designed to ignore values beyond the current
token. For example, when generating token 17 with a page list of [255,
254, 0], we should never need to read from the padding page.
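As a hypothetical illustration of that padding scheme (the helper and page ids below are invented for this example, not taken from service.py):

```python
def pad_page_list(page_ids, block_count, padding_page=0):
    # Fill unused page slots with the padding page (page 0) so every request
    # in the batch presents the same number of pages.
    return page_ids + [padding_page] * (block_count - len(page_ids))

# A request holding 17 tokens across two real pages, padded to a 3-page batch:
assert pad_page_list([255, 254], block_count=3) == [255, 254, 0]
```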

The issue stems from our current masking implementation. Instead of
directly ignoring values, we mask by adding negative infinity to them
before applying an exponential function. This typically works fine and
results in zeroes, since exp(-inf) is 0, but it breaks down on NaN
values: NaN plus negative infinity is still NaN. When that happens, NaN
values from page 0 leak into our calculations, resulting in token
corruption.
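A small NumPy sketch (not the actual attention kernel) shows why additive masking tolerates garbage values in masked slots but not NaNs:

```python
import numpy as np

mask = np.array([0.0, 0.0, -np.inf])    # additive mask: ignore the padding slot

clean = np.array([2.0, 1.0, 0.0])       # finite value read from the padding page
bad = np.array([2.0, 1.0, np.nan])      # NaN read from the padding page

def softmax(x):
    e = np.exp(x - np.max(x[np.isfinite(x)]))
    return e / e.sum()

print(softmax(clean + mask))  # [0.731, 0.269, 0.0]: exp(-inf) == 0, mask works
print(softmax(bad + mask))    # [nan, nan, nan]: nan + (-inf) is still nan and leaks
```

Padding seq_len and start_position with 1 instead of 0 (the second change in the diff below) keeps those NaNs from being produced in the first place.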
renxida authored Dec 10, 2024
1 parent 574861f commit 4c015d4
Showing 1 changed file with 8 additions and 3 deletions.
shortfin/python/shortfin_apps/llm/components/service.py
@@ -340,8 +340,9 @@ async def run(self):
         for r in self.exec_requests:
             assert r.start_position == 0

+        extra_token_slots = 1 if is_decode else 0
         bsl = max(
-            (r.start_position + len(r.input_token_ids)) for r in self.exec_requests
+            (extra_token_slots + len(r.input_token_ids)) for r in self.exec_requests
         )
         bsl = int(math.ceil(bsl / seq_stride) * seq_stride)
         block_count = bsl // seq_stride
@@ -389,13 +390,17 @@ async def run(self):
         if self.phase == InferencePhase.DECODE:
             start_positions_host = start_positions.for_transfer()
             with start_positions_host.map(discard=True) as m:
-                m.fill(0)
+                m.fill(
+                    1
+                )  # Pad unused requests. Must pad with nonzero value because division by 0 floods clobber page (page 0) in cache with NaN values.
                 m.items = [req.start_position for req in self.exec_requests]
             start_positions_host.copy_to(start_positions)

             seq_lens_host = seq_lens.for_transfer()
             with seq_lens_host.map(discard=True) as m:
-                m.fill(0)
+                m.fill(
+                    1
+                )  # Pad unused requests. Must pad with nonzero value because division by 0 floods clobber page (page 0) in cache with NaN values.
                 m.items = [
                     req.start_position + len(req.input_token_ids)
                     for req in self.exec_requests
