diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index 4dde584a8..e02d55851 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -13,12 +13,11 @@ import math - class BasePagedAttentionCache: """ Manages lifecycle of pages (using PageInfo as handles). - + Page States: Caching - Page can be read by multiple threads - Also maintains a reference count @@ -33,22 +32,28 @@ class BasePagedAttentionCache: - Single writer exclusive access in Writing state - Reference counting prevents eviction of in-use pages """ + def __init__(self, page_pool, tokens_per_page): self.page_pool = page_pool self.tokens_per_page = tokens_per_page - - def acquire_pages_for_tokens(self, tokens: List[int]) -> tuple[List[PageInfo], int)]: + def acquire_pages_for_tokens( + self, tokens: List[int], extra_token_slots: int = 1 + ) -> tuple[list[PageInfo], int]: """ Given a list of tokens, return a list of pages and a start position to continue generation from. + Parameters: + - tokens: all the known tokens for this generation request + - extra_token_slots: number of kvcache slots needed in addition to the ones needed to hold the given tokens. + In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable. The pages are returned in order. No token at idx < n_cached_token should be written to. TODO: consider enforcing this. 
""" - pages_needed = math.ceil(len(tokens) / self.tokens_per_page) + pages_needed = math.ceil((len(tokens) + extra_token_slots) / self.tokens_per_page) pages = self.page_pool.acquire_free_pages(pages_needed) n_cached_tokens = 0 @@ -64,7 +69,7 @@ def publish_pages(self, tokens, pages) -> None: It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)]. """ - pass # the base implementation doesn't cache unfinished requests. + pass # the base implementation doesn't cache unfinished requests. def release_pages(self, tokens, pages): """ @@ -72,9 +77,3 @@ def release_pages(self, tokens, pages): """ # in the base implementation, the pages can be owned by 1 request max, so they can be instantly release self.page_pool.release_pages(pages) - - - - - - diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index bcd08b756..efbcfc747 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -218,7 +218,11 @@ def board_prefills(self, cache: AttnPageCache): needed_pages = math.ceil( len(prefill_request.input_token_ids) / self.page_seq_stride ) - pages = cache.acquire_free_pages(needed_pages) + # allocate kv cache pages + pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + prefill_request.input_token_ids, + extra_token_slots=0, # prefill needs no extra kvcache slots to write to + ) if pages is None: logger.debug("Cannot fulfill request for %d pages", needed_pages) continue @@ -254,7 +258,11 @@ def board_decodes(self, cache: AttnPageCache): / self.page_seq_stride ) if needed_pages > len(decode_request.locked_pages): - pages = cache.acquire_free_pages(needed_pages) + # allocate kv cache pages + pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + decode_request.input_token_ids, + extra_token_slots=1, # need 1 extra slot to write result. 
+ ) if pages is None: logger.debug( "Cannot fulfill decode request for %d pages", needed_pages