Skip to content

Commit

Permalink
might be ready for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
renxida committed Nov 22, 2024
1 parent ee5af50 commit 503bd71
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
import math



class BasePagedAttentionCache:
"""
Manages lifecycle of pages (using PageInfo as handles).
Page States:
Caching - Page can be read by multiple threads
- Also maintains a reference count
Expand All @@ -33,22 +32,28 @@ class BasePagedAttentionCache:
- Single writer exclusive access in Writing state
- Reference counting prevents eviction of in-use pages
"""

def __init__(self, page_pool, tokens_per_page):
self.page_pool = page_pool
self.tokens_per_page = tokens_per_page


def acquire_pages_for_tokens(
    self, tokens: List[int], extra_token_slots: int = 1
) -> tuple[list[PageInfo], int]:
    """
    Given a list of tokens, return a list of pages and a start position to continue generation from.

    Parameters:
    - tokens: all the known tokens for this generation request
    - extra_token_slots: number of kvcache slots needed in addition to the ones needed to hold the given tokens.

    Returns:
    - (pages, n_cached_tokens): the pages, in order, and the number of leading
      tokens already covered by cached pages (always 0 in this base class).

    In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable.

    No token at idx < n_cached_tokens should be written to. TODO: consider enforcing this.
    """
    # Total slots = tokens already known + slots the caller still wants to write.
    # len() must wrap only `tokens`: `tokens + extra_token_slots` would try to
    # concatenate a list with an int and raise TypeError.
    slots_needed = len(tokens) + extra_token_slots
    pages_needed = math.ceil(slots_needed / self.tokens_per_page)

    # NOTE(review): callers compare the returned pages against None, so the pool
    # presumably signals "cannot allocate" that way — confirm in the page pool.
    pages = self.page_pool.acquire_free_pages(pages_needed)

    # The base implementation allocates only fresh pages, so no prefix is cached.
    n_cached_tokens = 0

    # NOTE(review): the end of this function is outside the visible diff hunk;
    # this return follows the annotated return type — confirm against the full file.
    return pages, n_cached_tokens
Expand All @@ -64,17 +69,11 @@ def publish_pages(self, tokens, pages) -> None:
It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)].
"""

pass # the base implementation doesn't cache unfinished requests.
pass # the base implementation doesn't cache unfinished requests.

def release_pages(self, tokens, pages):
    """
    Give up this request's hold on `pages`.

    Conceptually this decrements the pages' reference count; once it reaches
    zero they are eligible for eviction.

    Parameters:
    - tokens: not used by the base implementation.
    - pages: the pages to hand back to the page pool.
    """
    # In the base implementation a page belongs to at most one request, so
    # there is nothing to count: the pages can be released immediately.
    pool = self.page_pool
    pool.release_pages(pages)






12 changes: 10 additions & 2 deletions shortfin/python/shortfin_apps/llm/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,11 @@ def board_prefills(self, cache: AttnPageCache):
needed_pages = math.ceil(
len(prefill_request.input_token_ids) / self.page_seq_stride
)
pages = cache.acquire_free_pages(needed_pages)
# allocate kv cache pages
pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens(
prefill_request.input_token_ids,
extra_token_slots=0, # prefill needs no extra kvcache slots to write to
)
if pages is None:
logger.debug("Cannot fulfill request for %d pages", needed_pages)
continue
Expand Down Expand Up @@ -254,7 +258,11 @@ def board_decodes(self, cache: AttnPageCache):
/ self.page_seq_stride
)
if needed_pages > len(decode_request.locked_pages):
pages = cache.acquire_free_pages(needed_pages)
# allocate kv cache pages
pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens(
decode_request.input_token_ids,
extra_token_slots=1, # need 1 extra slot to write result.
)
if pages is None:
logger.debug(
"Cannot fulfill decode request for %d pages", needed_pages
Expand Down

0 comments on commit 503bd71

Please sign in to comment.