diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py new file mode 100644 index 000000000..4dde584a8 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -0,0 +1,80 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Base class for kv caches. +""" + +from typing import List +from attention_paging import PageInfo +import math + + + +class BasePagedAttentionCache: + """ + Manages lifecycle of pages (using PageInfo as handles). + + + Page States: + Caching - Page can be read by multiple threads + - Also maintains a reference count + Writing - Page is being modified by a single owner thread + + Transitions: + Caching -> Writing: When acquiring an unreferenced LRU leaf page for writing + Writing -> Caching: When writing is complete and page is released + + Thread Safety: + - Multiple readers allowed in ReadableCaching state + - Single writer exclusive access in Writing state + - Reference counting prevents eviction of in-use pages + """ + def __init__(self, page_pool, tokens_per_page): + self.page_pool = page_pool + self.tokens_per_page = tokens_per_page + + + def acquire_pages_for_tokens(self, tokens: List[int]) -> tuple[List[PageInfo], int)]: + """ + Given a list of tokens, return a list of pages and a start position to continue generation from. + + In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable. + + The pages are returned in order. + + No token at idx < n_cached_token should be written to. TODO: consider enforcing this. + """ + pages_needed = math.ceil(len(tokens) / self.tokens_per_page) + pages = self.page_pool.acquire_free_pages(pages_needed) + + n_cached_tokens = 0 + + return pages, n_cached_tokens + + def publish_pages(self, tokens, pages) -> None: + """ + Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests. + + Associates the tokens with the pages, and mark them as done writing. + + It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)]. + """ + + pass # the base implementation doesn't cache unfinished requests. + + def release_pages(self, tokens, pages): + """ + Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction. + """ + # in the base implementation, the pages can be owned by 1 request max, so they can be instantly release + self.page_pool.release_pages(pages) + + + + + + diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py index 125d00f58..2991e85e3 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py @@ -24,7 +24,8 @@ class PageInfo: pool: PagePool token_offset: int # Offset within the page token_count: int # Number of tokens stored in this page - ref_count: int = 0 # Number of references to this page in the radix tree + writing: bool = False + read_ref_count: int = 0 # Number of threads that still need to read this page. When this reaches 0, page is eligible for release @dataclass diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py index d9483e749..76f9e7a32 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import List, Dict, Optional, Tuple, TypeVar, Generic from dataclasses import dataclass - +from page_pool import PagePool T = TypeVar("T") # Generic type for page data @@ -74,7 +74,7 @@ class RadixTree(Generic[T]): """ def __init__( - self, *, page_pool: Any, tokens_per_page: int, disable: bool = False + self, page_pool: PagePool, tokens_per_page: int ) -> None: """Initialize the radix tree.