From 4a7a82afd9ea5b4c96871f35af42d939bc1e45b0 Mon Sep 17 00:00:00 2001 From: Cedar Date: Tue, 12 Nov 2024 21:20:19 -0800 Subject: [PATCH 01/23] some copy pasta boilerplate --- .../components/kvcache/attention_paging.py | 17 + .../llm/components/kvcache/radix_tree.py | 364 ++++++++++++++++++ 2 files changed, 381 insertions(+) create mode 100644 shortfin/python/shortfin_apps/llm/components/kvcache/attention_paging.py create mode 100644 shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/attention_paging.py b/shortfin/python/shortfin_apps/llm/components/kvcache/attention_paging.py new file mode 100644 index 000000000..3dee22ff6 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/attention_paging.py @@ -0,0 +1,17 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Sequence + +import logging +import math +import threading + +import shortfin as sf + +from .config_struct import ModelParams, human_size + +logger = logging.getLogger(__name__) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py new file mode 100644 index 000000000..add5c352f --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py @@ -0,0 +1,364 @@ +from __future__ import annotations +from typing import List, Tuple, Optional, Sequence +import threading +import logging +import shortfin as sf +from dataclasses import dataclass + +from ..config_struct import human_size +import math + +import time + +logger = logging.getLogger(__name__) + + +@dataclass +class PageInfo: + """ + Page index with some metadata about its contents. + """ + + page_index: int + in_use: bool + pool: PagePool + token_offset: int # Offset within the page + token_count: int # Number of tokens stored in this page + ref_count: int = 0 # Number of references to this page in the radix tree + + +@dataclass +class PagePoolConfig: + """ + Hyperparameters for the page pool. + """ + + device_block_count: int + dtype: sf.dtype + alloc_page_count: int + + paged_kv_block_size_elements: int # size of a single page as # of elements + # (e.g. one configuration for llama3.1 8b hax 32x2x16x8x128=1048576 elements where: + # 32: number of transformer blocks + # 2: one for k + one for v + # 16: tokens per page + # 8: head count (32 heads, but every 4 heads share the same kv buffer) + # 128: hidden dimension + + +class PagePool: + """Page table based attention cache. + + While internal to a model, the cache is organized with additional structure + per page, outside of the model, it is just a list of pages of a certain + element type and number of elements (all inner dims are flattened). + + One page table is allocated per device in a fiber. Currently, this is a + dense allocation with committed memory but in the future, we may just + allocate the address space and lazily populate it with committed memory. + + The cache is unique because usage of it can span fibers and concurrency + is implicitly managed at the block level (i.e. freshly acquired blocks + are assumed to be uninitialized and available immediately for use). 
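+
+    A minimal usage sketch (illustrative only; `fiber` and `pool_config` are
+    assumed to be an already-constructed fiber and PagePoolConfig):
+
+    ```python
+    pool = PagePool(devices=fiber.devices_dict.values(), config=pool_config)
+    pages = pool.acquire_free_pages(4)
+    if pages is None:
+        ...  # pool exhausted; caller must wait for pages to be released
+    else:
+        ...  # write KV data into pool.page_tables at the acquired page indices
+        pool.release_pages(pages)
+    ```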
+ + It is initialized with a discrete list of fiberd devices from a fiber but + cache usage can be done from any fiber which includes those devices. + """ + + def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig): + self._lock = threading.Lock() + self.devices = list(devices) + self.config = config + self.page_tables: list[sf.array.device_array] = [] + + # Setup accounting structs. + self.attn_page_entries = [ + PageInfo( + page_index=i, + in_use=False, + pool=self, + token_offset=0, + token_count=0, + ref_count=0, + ) + for i in range(self.config.alloc_page_count) + ] + + self.attn_page_free = list(self.attn_page_entries) + + # Initialize a page table on each device. + page_table_shape = [ + self.config.alloc_page_count, + self.config.paged_kv_block_size_elements, + ] + for device in devices: + logging.info( + "Allocating page table (shape=%r, dtype=%r, size=%s) on %r", + page_table_shape, + self.config.dtype, + human_size( + math.prod(page_table_shape) * self.config.dtype.dense_size_bytes + ), + device, + ) + page_table = sf.array.device_array.for_device( + device, page_table_shape, self.config.dtype + ) + self.page_tables.append(page_table) + + def acquire_free_pages(self, count: int) -> list[PageInfo] | None: + with self._lock: + available = len(self.attn_page_free) + if count > available: + return None + return [self.attn_page_free.pop() for _ in range(count)] + + def release_pages(self, pages: list[PageInfo]): + with self._lock: + self.attn_page_free.extend(pages) + + def __repr__(self): + # No need to lock for repr (list is internally synchronized). + free_pages = len(self.attn_page_free) + total_pages = len(self.attn_page_entries) + return ( + f"AttnPageCache({total_pages - free_pages}/{total_pages} pages in use: " + f"{100.0 * free_pages / total_pages}% free)" + ) + + +############################## begin radix attention + + +@dataclass +class RadixNode: + """Node in radix tree tracking pages""" + + children: dict[int, RadixNode] + parent: Optional[RadixNode] + key: List[int] + pages: List[PageInfo] + last_access_timestamp: int = 0 + ref_count: int = 0 + + +class RadixTree: + """ + Radix Tree for mapping token sequences to pages in the attention cache. + + Requests pages from a PagePool to store kvs for tokens in the sequence. 
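+
+    A rough usage sketch (illustrative; `pool` is an existing PagePool and
+    `token_ids` is a list of token ids):
+
+    ```python
+    tree = RadixTree(page_pool=pool, tokens_per_page=16)
+    cached_pages, node = tree.match_prefix(token_ids)  # reuse any cached prefix
+    node = tree.cache_sequence(token_ids, existing_pages=cached_pages)
+    ...  # run attention against node.pages
+    tree.release_pages(cached_pages)  # drop the read refs taken by match_prefix
+    ```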
+ """ + + def __init__( + self, *, page_pool: PagePool, tokens_per_page: int, disable: bool = False + ): + self._lock = threading.Lock() + self.page_pool = page_pool + self.disable = disable + self.tokens_per_page = tokens_per_page + self.reset() + + def reset(self) -> None: + """Reset the cache state""" + with self._lock: + # free + self.root = RadixNode( + children={}, parent=None, key=[], pages=[], ref_count=1 + ) + + def _get_match_len(self, key1: List[int], key2: List[int]) -> int: + """Return length of matching prefix between two keys""" + for i, (k1, k2) in enumerate(zip(key1, key2)): + if k1 != k2: + return i + return min(len(key1), len(key2)) + + def match_prefix(self, token_ids: List[int]) -> Tuple[List[PageInfo], RadixNode]: + """Find longest matching prefix and return its pages""" + if self.disable: + return [], self.root + + with self._lock: + matched_pages = [] + last_node = self.root + curr_node = self.root + remaining_tokens = token_ids + + while remaining_tokens: + first_token = remaining_tokens[0] + if first_token not in curr_node.children: + break + + child = curr_node.children[first_token] + match_len = self._get_match_len(child.key, remaining_tokens) + + if match_len < len(child.key): + # Partial match - need to split + new_node = self._split_node(child, match_len) + matched_pages.extend(new_node.pages) + last_node = new_node + break + else: + # Full match of this node + matched_pages.extend(child.pages) + last_node = child + remaining_tokens = remaining_tokens[match_len:] + curr_node = child + + # Update access time and ref counts + self._update_access_path(last_node) + for page in matched_pages: + page.ref_count += 1 + + return matched_pages, last_node + + def _split_node(self, node: RadixNode, split_pos: int) -> RadixNode: + """Split a node at given position, return new intermediate node""" + new_node = RadixNode( + children={}, + parent=node.parent, + key=node.key[:split_pos], + pages=node.pages[:split_pos], + ref_count=node.ref_count, + ) + + # Update the original node + node.parent.children[node.key[0]] = new_node + node.key = node.key[split_pos:] + node.pages = node.pages[split_pos:] + node.parent = new_node + new_node.children[node.key[0]] = node + + return new_node + + def _update_access_path(self, node: RadixNode) -> None: + """Update access timestamp along path to root""" + current_time = int(time.time()) + while node is not None: + node.last_access_timestamp = current_time + node = node.parent + + def cache_sequence( + self, token_ids: List[int], existing_pages: Optional[List[PageInfo]] = None + ) -> RadixNode: + """Cache a token sequence, potentially extending existing pages""" + with self._lock: + if existing_pages: + total_cached_tokens = sum(p.token_count for p in existing_pages) + new_tokens = token_ids[total_cached_tokens:] + if new_tokens: + new_pages = self._allocate_pages(len(new_tokens)) + pages = existing_pages + new_pages + else: + pages = existing_pages + else: + pages = self._allocate_pages(len(token_ids)) + + return self._insert_sequence(token_ids, pages) + + def _allocate_pages(self, token_count: int) -> List[PageInfo]: + """Allocate pages needed for token sequence""" + pages_needed = (token_count + self.tokens_per_page - 1) // self.tokens_per_page + page_entries = self.page_pool.acquire_free_pages(pages_needed) + + if not page_entries: + self._evict_pages(pages_needed) + page_entries = self.page_pool.acquire_free_pages(pages_needed) + if not page_entries: + raise RuntimeError( + f"Failed to allocate {pages_needed} pages after eviction" + ) 
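+
+        # Distribute the tokens across the acquired pages in order; every page
+        # is filled to tokens_per_page except possibly the last one.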
+ + pages = [] + tokens_remaining = token_count + for entry in page_entries: + tokens_in_page = min(self.tokens_per_page, tokens_remaining) + pages.append( + PageInfo( + page=entry, token_offset=0, token_count=tokens_in_page, ref_count=1 + ) + ) + tokens_remaining -= tokens_in_page + + return pages + + def _insert_sequence( + self, token_ids: List[int], pages: List[PageInfo] + ) -> RadixNode: + """Insert a sequence into the radix tree""" + curr_node = self.root + remaining_tokens = token_ids + + while remaining_tokens: + first_token = remaining_tokens[0] + if first_token not in curr_node.children: + # Create new leaf node + new_node = RadixNode( + children={}, + parent=curr_node, + key=remaining_tokens, + pages=pages[len(token_ids) - len(remaining_tokens) :], + ref_count=1, + ) + curr_node.children[first_token] = new_node + return new_node + + child = curr_node.children[first_token] + match_len = self._get_match_len(child.key, remaining_tokens) + + if match_len < len(child.key): + # Split existing node + split_node = self._split_node(child, match_len) + if match_len < len(remaining_tokens): + # Create new node for remaining tokens + new_node = RadixNode( + children={}, + parent=split_node, + key=remaining_tokens[match_len:], + pages=pages[ + len(token_ids) - len(remaining_tokens) + match_len : + ], + ref_count=1, + ) + split_node.children[remaining_tokens[match_len]] = new_node + return new_node + return split_node + + remaining_tokens = remaining_tokens[match_len:] + curr_node = child + + return curr_node + + def _evict_pages(self, pages_needed: int) -> None: + """Evict pages using LRU strategy""" + # Collect all nodes + nodes = [] + stack = [self.root] + while stack: + node = stack.pop() + stack.extend(node.children.values()) + if not node.children: # leaf node + nodes.append(node) + + # Sort by access time + nodes.sort(key=lambda n: n.last_access_timestamp) + + pages_freed = 0 + for node in nodes: + if node.ref_count == 0: + freeable_pages = [p for p in node.pages if p.ref_count == 0] + self.page_pool.release_pages([p.page for p in freeable_pages]) + pages_freed += len(freeable_pages) + + # Remove node if all pages freed + if len(freeable_pages) == len(node.pages): + del node.parent.children[node.key[0]] + + if pages_freed >= pages_needed: + break + + def release_pages(self, pages: List[PageInfo]) -> None: + """Release references to pages""" + with self._lock: + for page in pages: + page.ref_count -= 1 From 78c2c8f508076e6d02a4ce6642f7ceb1d3ae0ec9 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 13 Nov 2024 10:02:00 -0800 Subject: [PATCH 02/23] check in what i have now --- .../llm/components/kvcache/radix_tree.py | 76 ++++++++++++++++++- 1 file changed, 74 insertions(+), 2 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py index add5c352f..e82d75ad0 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py @@ -117,6 +117,36 @@ def release_pages(self, pages: list[PageInfo]): with self._lock: self.attn_page_free.extend(pages) + def copy_page(self, src_page: PageInfo) -> PageInfo: + """ + Copy a page's contents to a new page. + + Args: + src_page: Source page to copy from + token_count: Optional number of tokens to copy. If None, copies all tokens. 
+ + Returns: + New PageInfo containing the copied data + """ + with self._lock: + # Allocate new page + (dst_page,) = self.acquire_free_pages(1) + + # Copy the data on each device + for page_table in self.page_tables: + # View of source and destination pages + src_view = page_table.view((src_page.page_index,)) + dst_view = page_table.view((dst_page.page_index,)) + + # Copy the data + dst_view.copy_from(src_view) + + # Setup destination page metadata + dst_page.in_use = True + dst_page.token_offset = 0 # Always start at beginning of new page + + return dst_page + def __repr__(self): # No need to lock for repr (list is internally synchronized). free_pages = len(self.attn_page_free) @@ -214,21 +244,63 @@ def match_prefix(self, token_ids: List[int]) -> Tuple[List[PageInfo], RadixNode] def _split_node(self, node: RadixNode, split_pos: int) -> RadixNode: """Split a node at given position, return new intermediate node""" + # Calculate which page contains the split point + current_pos = 0 + split_page_idx = 0 + tokens_in_split_page = 0 + + for idx, page in enumerate(node.pages): + if current_pos + page.token_count > split_pos: + # This page contains our split point + split_page_idx = idx + tokens_in_split_page = split_pos - current_pos + break + current_pos += page.token_count + + # Handle the page containing split point + split_page = node.pages[split_page_idx] + if tokens_in_split_page > 0 and tokens_in_split_page < split_page.token_count: + # Need to copy-on-write this page + new_page = self.page_pool.copy_page( + split_page, token_count=tokens_in_split_page + ) + + # Adjust original page + split_page.token_offset += tokens_in_split_page + split_page.token_count -= tokens_in_split_page + + # Build new pages list with the copy + new_pages = ( + node.pages[:split_page_idx] + + [new_page] # Full pages before split + + node.pages[ # Copied partial page + split_page_idx + 1 : + ] # Any remaining full pages + ) + else: + # Split aligned with page boundary + new_pages = node.pages[:split_pos] + + # Create new node new_node = RadixNode( children={}, parent=node.parent, key=node.key[:split_pos], - pages=node.pages[:split_pos], + pages=new_pages, ref_count=node.ref_count, ) # Update the original node node.parent.children[node.key[0]] = new_node node.key = node.key[split_pos:] - node.pages = node.pages[split_pos:] + node.pages = node.pages[split_page_idx:] # Keep remaining pages node.parent = new_node new_node.children[node.key[0]] = node + # Increment ref counts for shared pages + for page in new_pages[:-1]: # All but last page are shared + page.ref_count += 1 + return new_node def _update_access_path(self, node: RadixNode) -> None: From e2e806fc69b11dbb9d26afbe543f21fb96b53852 Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 18 Nov 2024 14:26:45 -0800 Subject: [PATCH 03/23] split page pool and radix tree; add page pool test --- .../llm/components/kvcache/page_pool.py | 159 +++++ .../llm/components/kvcache/radix_tree.py | 600 +++++++----------- .../llm/components/kvcache/page_pool_test.py | 77 +++ 3 files changed, 450 insertions(+), 386 deletions(-) create mode 100644 shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py create mode 100644 shortfin/tests/apps/llm/components/kvcache/page_pool_test.py diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py new file mode 100644 index 000000000..125d00f58 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py @@ -0,0 +1,159 @@ 
+from __future__ import annotations +from typing import List, Tuple, Optional, Sequence +import threading +import logging +import shortfin as sf +import shortfin.array as sfnp +from dataclasses import dataclass + +from ..config_struct import human_size +import math + +import time + +logger = logging.getLogger(__name__) + + +@dataclass +class PageInfo: + """ + Page index with some metadata about its contents. + """ + + page_index: int + pool: PagePool + token_offset: int # Offset within the page + token_count: int # Number of tokens stored in this page + ref_count: int = 0 # Number of references to this page in the radix tree + + +@dataclass +class PagePoolConfig: + """ + Hyperparameters for the page pool. + """ + + dtype: sf.dtype + alloc_page_count: int + + paged_kv_block_size_elements: int # size of a single page as # of elements + # (e.g. one configuration for llama3.1 8b hax 32x2x16x8x128=1048576 elements where: + # 32: number of transformer blocks + # 2: one for k + one for v + # 16: tokens per page + # 8: head count (32 heads, but every 4 heads share the same kv buffer) + # 128: hidden dimension + + +class PagePool: + """Page table based attention cache. + + While internal to a model, the cache is organized with additional structure + per page, outside of the model, it is just a list of pages of a certain + element type and number of elements (all inner dims are flattened). + + One page table is allocated per device in a fiber. Currently, this is a + dense allocation with committed memory but in the future, we may just + allocate the address space and lazily populate it with committed memory. + + The cache is unique because usage of it can span fibers and concurrency + is implicitly managed at the block level (i.e. freshly acquired blocks + are assumed to be uninitialized and available immediately for use). + + It is initialized with a discrete list of fiberd devices from a fiber but + cache usage can be done from any fiber which includes those devices. + + In addition to supporting paged attention standalone, this also serves + as the array / buffer allocation layer for radix attention described in + `radix_tree.py`. + """ + + def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig): + self._lock = threading.Lock() + self.devices = list(devices) + self.config = config + self.page_tables: list[sf.array.device_array] = [] + + # Setup accounting structs. + self.attn_page_entries = [ + PageInfo( + page_index=i, + pool=self, + token_offset=0, + token_count=0, + ref_count=0, + ) + for i in range(self.config.alloc_page_count) + ] + + self.attn_page_free = list(self.attn_page_entries) + + # Initialize a page table on each device. 
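+        # Each table is a dense [alloc_page_count, paged_kv_block_size_elements]
+        # device array; a PageInfo.page_index selects the same row in every
+        # per-device table.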
+ page_table_shape = [ + self.config.alloc_page_count, + self.config.paged_kv_block_size_elements, + ] + for device in devices: + logging.info( + "Allocating page table (shape=%r, dtype=%r, size=%s) on %r", + page_table_shape, + self.config.dtype, + human_size(config.dtype.compute_dense_nd_size(page_table_shape)), + device, + ) + page_table = sf.array.device_array.for_device( + device, page_table_shape, self.config.dtype + ) + self.page_tables.append(page_table) + + def acquire_free_pages(self, count: int) -> list[PageInfo] | None: + with self._lock: + available = len(self.attn_page_free) + if count > available: + return None + return [self.attn_page_free.pop() for _ in range(count)] + + def release_pages(self, pages: list[PageInfo]): + with self._lock: + self.attn_page_free.extend(pages) + + def copy_page(self, src_page: PageInfo) -> PageInfo: + """ + Copy a page's contents to a new page. + + Args: + src_page: Source page to copy from + token_count: Optional number of tokens to copy. If None, copies all tokens. + + Returns: + New PageInfo containing the copied data + """ + # Allocate new page + (dst_page,) = self.acquire_free_pages(1) + + # fill src page with data + + # Copy the data on each device + for page_table in self.page_tables: + # View of source and destination pages + src_view = page_table.view(src_page.page_index) + dst_view = page_table.view(dst_page.page_index) + # Copy the data + dst_view.copy_from(src_view) + + # Setup destination page metadata + dst_page.token_offset = 0 # Always start at beginning of new page + + return dst_page + + def __repr__(self): + # No need to lock for repr (list is internally synchronized). + free_pages = len(self.attn_page_free) + total_pages = len(self.attn_page_entries) + return ( + f"AttnPageCache({total_pages - free_pages}/{total_pages} pages in use: " + f"{100.0 * free_pages / total_pages}% free)" + ) + + +############################## begin radix attention diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py index e82d75ad0..d9483e749 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py @@ -1,436 +1,264 @@ from __future__ import annotations -from typing import List, Tuple, Optional, Sequence -import threading -import logging -import shortfin as sf +from typing import List, Dict, Optional, Tuple, TypeVar, Generic from dataclasses import dataclass -from ..config_struct import human_size -import math - -import time - -logger = logging.getLogger(__name__) +T = TypeVar("T") # Generic type for page data @dataclass -class PageInfo: - """ - Page index with some metadata about its contents. - """ - - page_index: int - in_use: bool - pool: PagePool - token_offset: int # Offset within the page - token_count: int # Number of tokens stored in this page - ref_count: int = 0 # Number of references to this page in the radix tree +class RadixNode(Generic[T]): + """A node in the radix tree that tracks cached pages of data. + + Each node represents a sequence of tokens and maintains references to the pages + containing the cached data for those tokens. The node structure allows for + efficient prefix matching and sharing of cached data. 
+ + Attributes: + children: Mapping of first token to child nodes + parent: Reference to parent node, None for root + key: Token sequence this node represents + pages: List of page data associated with this token sequence + last_access_timestamp: Unix timestamp of last access + ref_count: Number of active references to this node + + Example: + ```python + # Create a leaf node for token sequence [5, 2, 8] + node = RadixNode( + children={}, + parent=parent_node, + key=[5, 2, 8], + pages=[page1, page2], + ref_count=1 + ) + # Access timestamp is automatically updated when node is accessed + assert node.last_access_timestamp > 0 -@dataclass -class PagePoolConfig: - """ - Hyperparameters for the page pool. + # When done with node, decrement reference count + node.ref_count -= 1 + ``` """ - device_block_count: int - dtype: sf.dtype - alloc_page_count: int + children: Dict[int, RadixNode[T]] + parent: Optional[RadixNode[T]] + key: List[int] + pages: List[T] + last_access_timestamp: int = 0 + ref_count: int = 0 - paged_kv_block_size_elements: int # size of a single page as # of elements - # (e.g. one configuration for llama3.1 8b hax 32x2x16x8x128=1048576 elements where: - # 32: number of transformer blocks - # 2: one for k + one for v - # 16: tokens per page - # 8: head count (32 heads, but every 4 heads share the same kv buffer) - # 128: hidden dimension +class RadixTree(Generic[T]): + """A radix tree implementation for caching token sequence data. -class PagePool: - """Page table based attention cache. + The tree efficiently stores and retrieves cached data for token sequences, + enabling prefix sharing and fast lookups. It handles memory management through + reference counting and LRU eviction. - While internal to a model, the cache is organized with additional structure - per page, outside of the model, it is just a list of pages of a certain - element type and number of elements (all inner dims are flattened). + Example: + ```python + # Initialize tree with a page pool and 16 tokens per page + tree = RadixTree(page_pool=my_pool, tokens_per_page=16) - One page table is allocated per device in a fiber. Currently, this is a - dense allocation with committed memory but in the future, we may just - allocate the address space and lazily populate it with committed memory. + # Cache a sequence of tokens with their associated data + token_ids = [1, 5, 8, 2] + node = tree.cache_sequence(token_ids) - The cache is unique because usage of it can span fibers and concurrency - is implicitly managed at the block level (i.e. freshly acquired blocks - are assumed to be uninitialized and available immediately for use). + # Later, find cached data for a prefix + pages, match_node = tree.match_prefix([1, 5, 8]) + assert len(pages) > 0 - It is initialized with a discrete list of fiberd devices from a fiber but - cache usage can be done from any fiber which includes those devices. + # When done with the cached data, release it + tree.release_pages(pages) + ``` """ - def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig): - self._lock = threading.Lock() - self.devices = list(devices) - self.config = config - self.page_tables: list[sf.array.device_array] = [] - - # Setup accounting structs. - self.attn_page_entries = [ - PageInfo( - page_index=i, - in_use=False, - pool=self, - token_offset=0, - token_count=0, - ref_count=0, - ) - for i in range(self.config.alloc_page_count) - ] - - self.attn_page_free = list(self.attn_page_entries) - - # Initialize a page table on each device. 
- page_table_shape = [ - self.config.alloc_page_count, - self.config.paged_kv_block_size_elements, - ] - for device in devices: - logging.info( - "Allocating page table (shape=%r, dtype=%r, size=%s) on %r", - page_table_shape, - self.config.dtype, - human_size( - math.prod(page_table_shape) * self.config.dtype.dense_size_bytes - ), - device, - ) - page_table = sf.array.device_array.for_device( - device, page_table_shape, self.config.dtype + def __init__( + self, *, page_pool: Any, tokens_per_page: int, disable: bool = False + ) -> None: + """Initialize the radix tree. + + Args: + page_pool: Pool that manages the underlying page allocations + tokens_per_page: Number of tokens that can be stored in each page + disable: If True, disables caching (useful for testing) + + Example: + ```python + tree = RadixTree( + page_pool=PagePool(...), + tokens_per_page=16, + disable=False ) - self.page_tables.append(page_table) + ``` + """ + raise NotImplementedError() + + def reset(self) -> None: + """Reset the tree to initial empty state. - def acquire_free_pages(self, count: int) -> list[PageInfo] | None: - with self._lock: - available = len(self.attn_page_free) - if count > available: - return None - return [self.attn_page_free.pop() for _ in range(count)] + Releases all cached pages and resets the tree to contain only a root node. - def release_pages(self, pages: list[PageInfo]): - with self._lock: - self.attn_page_free.extend(pages) + Example: + ```python + tree = RadixTree(...) - def copy_page(self, src_page: PageInfo) -> PageInfo: + # Cache some sequences + tree.cache_sequence([1, 2, 3]) + tree.cache_sequence([4, 5, 6]) + + # Reset tree to clean state + tree.reset() + + # Tree is now empty except for root node + pages, _ = tree.match_prefix([1, 2, 3]) + assert len(pages) == 0 + ``` """ - Copy a page's contents to a new page. + raise NotImplementedError() + + def match_prefix(self, token_ids: List[int]) -> Tuple[List[T], RadixNode[T]]: + """Find the longest matching prefix and return its cached pages. Args: - src_page: Source page to copy from - token_count: Optional number of tokens to copy. If None, copies all tokens. + token_ids: Sequence of tokens to match against Returns: - New PageInfo containing the copied data + Tuple containing: + - List of cached pages for the matching prefix + - The node containing the last matched token + + Example: + ```python + # Cache a sequence + tree.cache_sequence([1, 2, 3, 4, 5]) + + # Match a prefix + pages, node = tree.match_prefix([1, 2, 3]) + + # pages contains cached data for tokens [1, 2, 3] + assert len(pages) > 0 + + # node represents the position after [1, 2, 3] + assert node.key == [1, 2, 3] + + # Don't forget to release when done + tree.release_pages(pages) + ``` """ - with self._lock: - # Allocate new page - (dst_page,) = self.acquire_free_pages(1) - - # Copy the data on each device - for page_table in self.page_tables: - # View of source and destination pages - src_view = page_table.view((src_page.page_index,)) - dst_view = page_table.view((dst_page.page_index,)) - - # Copy the data - dst_view.copy_from(src_view) - - # Setup destination page metadata - dst_page.in_use = True - dst_page.token_offset = 0 # Always start at beginning of new page - - return dst_page - - def __repr__(self): - # No need to lock for repr (list is internally synchronized). 
- free_pages = len(self.attn_page_free) - total_pages = len(self.attn_page_entries) - return ( - f"AttnPageCache({total_pages - free_pages}/{total_pages} pages in use: " - f"{100.0 * free_pages / total_pages}% free)" - ) + raise NotImplementedError() + def cache_sequence( + self, token_ids: List[int], existing_pages: Optional[List[T]] = None + ) -> RadixNode[T]: + """Cache a token sequence, potentially extending existing cached pages. -############################## begin radix attention + Args: + token_ids: Complete sequence of tokens to cache + existing_pages: Optional list of already cached pages to extend + Returns: + Node containing the cached sequence -@dataclass -class RadixNode: - """Node in radix tree tracking pages""" + Example: + ```python + # Cache initial sequence + node1 = tree.cache_sequence([1, 2, 3]) - children: dict[int, RadixNode] - parent: Optional[RadixNode] - key: List[int] - pages: List[PageInfo] - last_access_timestamp: int = 0 - ref_count: int = 0 + # Match prefix and extend with new tokens + pages, _ = tree.match_prefix([1, 2, 3]) + node2 = tree.cache_sequence([1, 2, 3, 4, 5], existing_pages=pages) + # New node contains extended sequence + assert node2.key == [1, 2, 3, 4, 5] -class RadixTree: - """ - Radix Tree for mapping token sequences to pages in the attention cache. + # Release pages when done + tree.release_pages(pages) + ``` + """ + raise NotImplementedError() - Requests pages from a PagePool to store kvs for tokens in the sequence. - """ + def release_pages(self, pages: List[T]) -> None: + """Release references to cached pages. - def __init__( - self, *, page_pool: PagePool, tokens_per_page: int, disable: bool = False - ): - self._lock = threading.Lock() - self.page_pool = page_pool - self.disable = disable - self.tokens_per_page = tokens_per_page - self.reset() + Decrements reference counts and potentially frees memory if counts reach zero. - def reset(self) -> None: - """Reset the cache state""" - with self._lock: - # free - self.root = RadixNode( - children={}, parent=None, key=[], pages=[], ref_count=1 - ) + Args: + pages: List of pages to release + + Example: + ```python + # Get cached pages + pages, _ = tree.match_prefix([1, 2, 3]) + + # Use the pages... 
+ + # Release when done + tree.release_pages(pages) + ``` + """ + raise NotImplementedError() def _get_match_len(self, key1: List[int], key2: List[int]) -> int: - """Return length of matching prefix between two keys""" - for i, (k1, k2) in enumerate(zip(key1, key2)): - if k1 != k2: - return i - return min(len(key1), len(key2)) - - def match_prefix(self, token_ids: List[int]) -> Tuple[List[PageInfo], RadixNode]: - """Find longest matching prefix and return its pages""" - if self.disable: - return [], self.root - - with self._lock: - matched_pages = [] - last_node = self.root - curr_node = self.root - remaining_tokens = token_ids - - while remaining_tokens: - first_token = remaining_tokens[0] - if first_token not in curr_node.children: - break - - child = curr_node.children[first_token] - match_len = self._get_match_len(child.key, remaining_tokens) - - if match_len < len(child.key): - # Partial match - need to split - new_node = self._split_node(child, match_len) - matched_pages.extend(new_node.pages) - last_node = new_node - break - else: - # Full match of this node - matched_pages.extend(child.pages) - last_node = child - remaining_tokens = remaining_tokens[match_len:] - curr_node = child - - # Update access time and ref counts - self._update_access_path(last_node) - for page in matched_pages: - page.ref_count += 1 - - return matched_pages, last_node - - def _split_node(self, node: RadixNode, split_pos: int) -> RadixNode: - """Split a node at given position, return new intermediate node""" - # Calculate which page contains the split point - current_pos = 0 - split_page_idx = 0 - tokens_in_split_page = 0 - - for idx, page in enumerate(node.pages): - if current_pos + page.token_count > split_pos: - # This page contains our split point - split_page_idx = idx - tokens_in_split_page = split_pos - current_pos - break - current_pos += page.token_count - - # Handle the page containing split point - split_page = node.pages[split_page_idx] - if tokens_in_split_page > 0 and tokens_in_split_page < split_page.token_count: - # Need to copy-on-write this page - new_page = self.page_pool.copy_page( - split_page, token_count=tokens_in_split_page - ) + """Return length of matching prefix between two keys. - # Adjust original page - split_page.token_offset += tokens_in_split_page - split_page.token_count -= tokens_in_split_page - - # Build new pages list with the copy - new_pages = ( - node.pages[:split_page_idx] - + [new_page] # Full pages before split - + node.pages[ # Copied partial page - split_page_idx + 1 : - ] # Any remaining full pages - ) - else: - # Split aligned with page boundary - new_pages = node.pages[:split_pos] + Args: + key1: First sequence of tokens + key2: Second sequence of tokens - # Create new node - new_node = RadixNode( - children={}, - parent=node.parent, - key=node.key[:split_pos], - pages=new_pages, - ref_count=node.ref_count, - ) + Returns: + Length of the matching prefix + + Example: + ```python + # Internal use for finding split points + length = tree._get_match_len([1, 2, 3, 4], [1, 2, 5, 6]) + assert length == 2 # Matches [1, 2] + ``` + """ + raise NotImplementedError() - # Update the original node - node.parent.children[node.key[0]] = new_node - node.key = node.key[split_pos:] - node.pages = node.pages[split_page_idx:] # Keep remaining pages - node.parent = new_node - new_node.children[node.key[0]] = node + def _split_node(self, node: RadixNode[T], split_pos: int) -> RadixNode[T]: + """Split a node at the given position. 
- # Increment ref counts for shared pages - for page in new_pages[:-1]: # All but last page are shared - page.ref_count += 1 + Args: + node: Node to split + split_pos: Position in the node's key where split should occur - return new_node + Returns: + New intermediate node created by the split - def _update_access_path(self, node: RadixNode) -> None: - """Update access timestamp along path to root""" - current_time = int(time.time()) - while node is not None: - node.last_access_timestamp = current_time - node = node.parent + Example: + ```python + # Internal use during insertion + # If we have a node with key [1, 2, 3, 4] and need to + # insert [1, 2, 5, 6], we first split at position 2: - def cache_sequence( - self, token_ids: List[int], existing_pages: Optional[List[PageInfo]] = None - ) -> RadixNode: - """Cache a token sequence, potentially extending existing pages""" - with self._lock: - if existing_pages: - total_cached_tokens = sum(p.token_count for p in existing_pages) - new_tokens = token_ids[total_cached_tokens:] - if new_tokens: - new_pages = self._allocate_pages(len(new_tokens)) - pages = existing_pages + new_pages - else: - pages = existing_pages - else: - pages = self._allocate_pages(len(token_ids)) - - return self._insert_sequence(token_ids, pages) - - def _allocate_pages(self, token_count: int) -> List[PageInfo]: - """Allocate pages needed for token sequence""" - pages_needed = (token_count + self.tokens_per_page - 1) // self.tokens_per_page - page_entries = self.page_pool.acquire_free_pages(pages_needed) - - if not page_entries: - self._evict_pages(pages_needed) - page_entries = self.page_pool.acquire_free_pages(pages_needed) - if not page_entries: - raise RuntimeError( - f"Failed to allocate {pages_needed} pages after eviction" - ) - - pages = [] - tokens_remaining = token_count - for entry in page_entries: - tokens_in_page = min(self.tokens_per_page, tokens_remaining) - pages.append( - PageInfo( - page=entry, token_offset=0, token_count=tokens_in_page, ref_count=1 - ) - ) - tokens_remaining -= tokens_in_page - - return pages - - def _insert_sequence( - self, token_ids: List[int], pages: List[PageInfo] - ) -> RadixNode: - """Insert a sequence into the radix tree""" - curr_node = self.root - remaining_tokens = token_ids - - while remaining_tokens: - first_token = remaining_tokens[0] - if first_token not in curr_node.children: - # Create new leaf node - new_node = RadixNode( - children={}, - parent=curr_node, - key=remaining_tokens, - pages=pages[len(token_ids) - len(remaining_tokens) :], - ref_count=1, - ) - curr_node.children[first_token] = new_node - return new_node - - child = curr_node.children[first_token] - match_len = self._get_match_len(child.key, remaining_tokens) - - if match_len < len(child.key): - # Split existing node - split_node = self._split_node(child, match_len) - if match_len < len(remaining_tokens): - # Create new node for remaining tokens - new_node = RadixNode( - children={}, - parent=split_node, - key=remaining_tokens[match_len:], - pages=pages[ - len(token_ids) - len(remaining_tokens) + match_len : - ], - ref_count=1, - ) - split_node.children[remaining_tokens[match_len]] = new_node - return new_node - return split_node - - remaining_tokens = remaining_tokens[match_len:] - curr_node = child - - return curr_node + old_node.key == [1, 2, 3, 4] + new_node = tree._split_node(old_node, 2) + + assert new_node.key == [1, 2] + assert old_node.key == [3, 4] + assert old_node.parent == new_node + ``` + """ + raise NotImplementedError() def _evict_pages(self, 
pages_needed: int) -> None: - """Evict pages using LRU strategy""" - # Collect all nodes - nodes = [] - stack = [self.root] - while stack: - node = stack.pop() - stack.extend(node.children.values()) - if not node.children: # leaf node - nodes.append(node) - - # Sort by access time - nodes.sort(key=lambda n: n.last_access_timestamp) - - pages_freed = 0 - for node in nodes: - if node.ref_count == 0: - freeable_pages = [p for p in node.pages if p.ref_count == 0] - self.page_pool.release_pages([p.page for p in freeable_pages]) - pages_freed += len(freeable_pages) - - # Remove node if all pages freed - if len(freeable_pages) == len(node.pages): - del node.parent.children[node.key[0]] - - if pages_freed >= pages_needed: - break - - def release_pages(self, pages: List[PageInfo]) -> None: - """Release references to pages""" - with self._lock: - for page in pages: - page.ref_count -= 1 + """Evict pages using LRU strategy until enough pages are free. + + Args: + pages_needed: Number of pages that need to be freed + + Example: + ```python + # Internal use when cache is full + # If we need 5 pages and cache is full: + tree._evict_pages(5) + + # After eviction, at least 5 pages should be available + pages = page_pool.acquire_free_pages(5) + assert pages is not None + ``` + """ + raise NotImplementedError() diff --git a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py new file mode 100644 index 000000000..748888f0a --- /dev/null +++ b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py @@ -0,0 +1,77 @@ +import pytest +import logging +from shortfin_apps.llm.components.kvcache.page_pool import PagePool, PagePoolConfig +import shortfin as sf +import shortfin.host +import shortfin.array as sfnp +import shortfin.amdgpu + +logger = logging.getLogger(__name__) + + +@pytest.fixture( + params=[("cpu", sf.host.CPUSystemBuilder), ("gpu", sf.amdgpu.SystemBuilder)] +) +def setup_system(request): + system_type, builder_class = request.param + logger.info(f"=== Setting up {system_type.upper()} system ===") + sc = builder_class() + lsys = sc.create_system() + fiber = lsys.create_fiber() + devices = fiber.devices_dict.values() + yield system_type, lsys, devices + lsys.shutdown() + + +@pytest.fixture +def setup_pool(setup_system): + system_type, _, devices = setup_system + logger.info(f"Creating PagePool for {system_type.upper()} system") + pool = PagePool( + devices=devices, + config=PagePoolConfig( + alloc_page_count=256, + dtype=sfnp.float16, + paged_kv_block_size_elements=393216, + ), + ) + return system_type, pool + + +def test_page_acquisition(setup_pool): + system_type, pool = setup_pool + logger.info( + f"=== Running page acquisition test on {system_type.upper()} system ===" + ) + page0 = pool.acquire_free_pages(1) + assert page0 is not None, f"Failed to acquire a free page on {system_type} system" + logger.info(f"Successfully acquired page on {system_type.upper()} system") + + +def test_page_copy(setup_pool): + system_type, pool = setup_pool + logger.info(f"=== Running page copy test on {system_type.upper()} system ===") + (page0,) = pool.acquire_free_pages(1) + page1 = pool.copy_page(page0) + assert page1 is not None, f"Failed to copy a page on {system_type} system" + assert ( + page0 != page1 + ), f"Copied page should be different from original on {system_type} system" + logger.info(f"Successfully copied page on {system_type.upper()} system") + + +@pytest.fixture(autouse=True) +def setup_logging(): + """Set up logging format to 
include timestamp and level""" + logging.basicConfig( + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, + force=True, + ) + + +# Add more tests as needed + +if __name__ == "__main__": + pytest.main([__file__]) From 8fa4b921723992fab1116ac2c137923c44a061f1 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 21 Nov 2024 15:42:09 +0000 Subject: [PATCH 04/23] cleanup unused file --- .../llm/components/kvcache/attention_paging.py | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 shortfin/python/shortfin_apps/llm/components/kvcache/attention_paging.py diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/attention_paging.py b/shortfin/python/shortfin_apps/llm/components/kvcache/attention_paging.py deleted file mode 100644 index 3dee22ff6..000000000 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/attention_paging.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Sequence - -import logging -import math -import threading - -import shortfin as sf - -from .config_struct import ModelParams, human_size - -logger = logging.getLogger(__name__) From a0d1cb2e5477debb3e3e42ec0ec8111eabd2b91b Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 21 Nov 2024 15:42:49 +0000 Subject: [PATCH 05/23] add current work --- .../kvcache/base_attention_cache.py | 80 +++++++++++++++++++ .../llm/components/kvcache/page_pool.py | 3 +- .../llm/components/kvcache/radix_tree.py | 4 +- 3 files changed, 84 insertions(+), 3 deletions(-) create mode 100644 shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py new file mode 100644 index 000000000..4dde584a8 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -0,0 +1,80 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Base class for kv caches. +""" + +from typing import List +from attention_paging import PageInfo +import math + + + +class BasePagedAttentionCache: + """ + Manages lifecycle of pages (using PageInfo as handles). + + + Page States: + Caching - Page can be read by multiple threads + - Also maintains a reference count + Writing - Page is being modified by a single owner thread + + Transitions: + Caching -> Writing: When acquiring an unreferenced LRU leaf page for writing + Writing -> Caching: When writing is complete and page is released + + Thread Safety: + - Multiple readers allowed in ReadableCaching state + - Single writer exclusive access in Writing state + - Reference counting prevents eviction of in-use pages + """ + def __init__(self, page_pool, tokens_per_page): + self.page_pool = page_pool + self.tokens_per_page = tokens_per_page + + + def acquire_pages_for_tokens(self, tokens: List[int]) -> tuple[List[PageInfo], int)]: + """ + Given a list of tokens, return a list of pages and a start position to continue generation from. 
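+
+        Returns a pair (pages, n_cached_tokens): KV for tokens[:n_cached_tokens] is
+        already present in the returned pages, so the caller only needs to compute
+        and write KV for tokens[n_cached_tokens:].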
+ + In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable. + + The pages are returned in order. + + No token at idx < n_cached_token should be written to. TODO: consider enforcing this. + """ + pages_needed = math.ceil(len(tokens) / self.tokens_per_page) + pages = self.page_pool.acquire_free_pages(pages_needed) + + n_cached_tokens = 0 + + return pages, n_cached_tokens + + def publish_pages(self, tokens, pages) -> None: + """ + Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests. + + Associates the tokens with the pages, and mark them as done writing. + + It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)]. + """ + + pass # the base implementation doesn't cache unfinished requests. + + def release_pages(self, tokens, pages): + """ + Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction. + """ + # in the base implementation, the pages can be owned by 1 request max, so they can be instantly release + self.page_pool.release_pages(pages) + + + + + + diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py index 125d00f58..2991e85e3 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py @@ -24,7 +24,8 @@ class PageInfo: pool: PagePool token_offset: int # Offset within the page token_count: int # Number of tokens stored in this page - ref_count: int = 0 # Number of references to this page in the radix tree + writing: bool = False + read_ref_count: int = 0 # Number of threads that still need to read this page. When this reaches 0, page is eligible for release @dataclass diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py index d9483e749..76f9e7a32 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import List, Dict, Optional, Tuple, TypeVar, Generic from dataclasses import dataclass - +from page_pool import PagePool T = TypeVar("T") # Generic type for page data @@ -74,7 +74,7 @@ class RadixTree(Generic[T]): """ def __init__( - self, *, page_pool: Any, tokens_per_page: int, disable: bool = False + self, page_pool: PagePool, tokens_per_page: int ) -> None: """Initialize the radix tree. 
From e74fcce7c801b2476412b8b15a00c7b8a8e308dc Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 22 Nov 2024 09:55:25 +0000 Subject: [PATCH 06/23] might be ready for testing --- .../kvcache/base_attention_cache.py | 23 +++++++++---------- .../shortfin_apps/llm/components/service.py | 12 ++++++++-- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index 4dde584a8..e02d55851 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -13,12 +13,11 @@ import math - class BasePagedAttentionCache: """ Manages lifecycle of pages (using PageInfo as handles). - + Page States: Caching - Page can be read by multiple threads - Also maintains a reference count @@ -33,22 +32,28 @@ class BasePagedAttentionCache: - Single writer exclusive access in Writing state - Reference counting prevents eviction of in-use pages """ + def __init__(self, page_pool, tokens_per_page): self.page_pool = page_pool self.tokens_per_page = tokens_per_page - - def acquire_pages_for_tokens(self, tokens: List[int]) -> tuple[List[PageInfo], int)]: + def acquire_pages_for_tokens( + self, tokens: List[int], extra_token_slots: int = 1 + ) -> tuple[list[PageInfo], int]: """ Given a list of tokens, return a list of pages and a start position to continue generation from. + Parameters: + - tokens: all the known tokens for this generation request + - extra_token_slots: number of kvcache slots needed in addition to the ones needed to hold the given tokens. + In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable. The pages are returned in order. No token at idx < n_cached_token should be written to. TODO: consider enforcing this. """ - pages_needed = math.ceil(len(tokens) / self.tokens_per_page) + pages_needed = math.ceil(len(tokens + extra_token_slots) / self.tokens_per_page) pages = self.page_pool.acquire_free_pages(pages_needed) n_cached_tokens = 0 @@ -64,7 +69,7 @@ def publish_pages(self, tokens, pages) -> None: It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)]. """ - pass # the base implementation doesn't cache unfinished requests. + pass # the base implementation doesn't cache unfinished requests. 
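+        # Caching subclasses are expected to record the tokens -> pages mapping
+        # here (e.g. in a radix tree keyed on the token prefix) so that later
+        # requests sharing a prefix can reuse the already-computed KV pages.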
def release_pages(self, tokens, pages): """ @@ -72,9 +77,3 @@ def release_pages(self, tokens, pages): """ # in the base implementation, the pages can be owned by 1 request max, so they can be instantly release self.page_pool.release_pages(pages) - - - - - - diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index bcd08b756..efbcfc747 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -218,7 +218,11 @@ def board_prefills(self, cache: AttnPageCache): needed_pages = math.ceil( len(prefill_request.input_token_ids) / self.page_seq_stride ) - pages = cache.acquire_free_pages(needed_pages) + # allocate kv cache pages + pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + prefill_request.input_token_ids, + extra_token_slots=0, # prefill needs no extra kvcache slots to write to + ) if pages is None: logger.debug("Cannot fulfill request for %d pages", needed_pages) continue @@ -254,7 +258,11 @@ def board_decodes(self, cache: AttnPageCache): / self.page_seq_stride ) if needed_pages > len(decode_request.locked_pages): - pages = cache.acquire_free_pages(needed_pages) + # allocate kv cache pages + pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + decode_request.input_token_ids, + extra_token_slots=1, # need 1 extra slot to write result. + ) if pages is None: logger.debug( "Cannot fulfill decode request for %d pages", needed_pages From 405b2c2623b9e5e6dc12344ddad6c38818bd9ed7 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 22 Nov 2024 09:57:39 +0000 Subject: [PATCH 07/23] fix precommit formatting for my notes file --- .../shortfin_apps/llm/components/kvcache/radix_tree.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py index 76f9e7a32..aca2511b8 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py @@ -2,6 +2,7 @@ from typing import List, Dict, Optional, Tuple, TypeVar, Generic from dataclasses import dataclass from page_pool import PagePool + T = TypeVar("T") # Generic type for page data @@ -73,9 +74,7 @@ class RadixTree(Generic[T]): ``` """ - def __init__( - self, page_pool: PagePool, tokens_per_page: int - ) -> None: + def __init__(self, page_pool: PagePool, tokens_per_page: int) -> None: """Initialize the radix tree. 
Args: From 90870aa85c47d7abf9fb162eeb871654d7d1b5a8 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 22 Nov 2024 13:03:52 +0000 Subject: [PATCH 08/23] fix tests --- .../shortfin_apps/llm/components/kvcache/page_pool.py | 1 - .../tests/apps/llm/components/kvcache/page_pool_test.py | 7 ++++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py index 2991e85e3..3d106869c 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py @@ -82,7 +82,6 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig pool=self, token_offset=0, token_count=0, - ref_count=0, ) for i in range(self.config.alloc_page_count) ] diff --git a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py index 748888f0a..b90f7daff 100644 --- a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py @@ -10,7 +10,12 @@ @pytest.fixture( - params=[("cpu", sf.host.CPUSystemBuilder), ("gpu", sf.amdgpu.SystemBuilder)] + params=[ + pytest.param("cpu", sf.host.CPUSystemBuilder, marks=[]), + pytest.param( + "gpu", sf.amdgpu.SystemBuilder, marks=[pytest.mark.system("amdgpu")] + ), + ] ) def setup_system(request): system_type, builder_class = request.param From 409a3d815c051683072aa817486a46b98546848f Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 22 Nov 2024 14:25:35 +0000 Subject: [PATCH 09/23] fix cache test --- .../tests/apps/llm/components/kvcache/page_pool_test.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py index b90f7daff..95d35b341 100644 --- a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py @@ -10,15 +10,12 @@ @pytest.fixture( - params=[ - pytest.param("cpu", sf.host.CPUSystemBuilder, marks=[]), - pytest.param( - "gpu", sf.amdgpu.SystemBuilder, marks=[pytest.mark.system("amdgpu")] - ), - ] + params=[("cpu", sf.host.CPUSystemBuilder), ("gpu", sf.amdgpu.SystemBuilder)] ) def setup_system(request): system_type, builder_class = request.param + if system_type == "gpu" and not "gpu" in pytest.config.getoption("--system"): + pytest.skip("Skipping GPU-specific test") logger.info(f"=== Setting up {system_type.upper()} system ===") sc = builder_class() lsys = sc.create_system() From 1ab5505fee20270880623de484fb7fc9cf318f28 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 22 Nov 2024 14:45:30 +0000 Subject: [PATCH 10/23] device construction fix on non-amdgpu --- .../llm/components/kvcache/page_pool_test.py | 32 +++++++----------- shortfin/tests/conftest.py | 33 +++++++++++++++++++ 2 files changed, 45 insertions(+), 20 deletions(-) diff --git a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py index 95d35b341..f9e0f6e4c 100644 --- a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py @@ -14,8 +14,6 @@ ) def setup_system(request): system_type, builder_class = request.param - if system_type == "gpu" and not "gpu" in pytest.config.getoption("--system"): - 
pytest.skip("Skipping GPU-specific test") logger.info(f"=== Setting up {system_type.upper()} system ===") sc = builder_class() lsys = sc.create_system() @@ -26,40 +24,34 @@ def setup_system(request): @pytest.fixture -def setup_pool(setup_system): - system_type, _, devices = setup_system - logger.info(f"Creating PagePool for {system_type.upper()} system") +def setup_pool(generic_device): pool = PagePool( - devices=devices, + devices=[generic_device], config=PagePoolConfig( alloc_page_count=256, dtype=sfnp.float16, paged_kv_block_size_elements=393216, ), ) - return system_type, pool + return pool def test_page_acquisition(setup_pool): - system_type, pool = setup_pool - logger.info( - f"=== Running page acquisition test on {system_type.upper()} system ===" - ) + pool = setup_pool + logger.info(f"=== Running page acquisition test on system ===") page0 = pool.acquire_free_pages(1) - assert page0 is not None, f"Failed to acquire a free page on {system_type} system" - logger.info(f"Successfully acquired page on {system_type.upper()} system") + assert page0 is not None, f"Failed to acquire a free page on system" + logger.info(f"Successfully acquired page on system") def test_page_copy(setup_pool): - system_type, pool = setup_pool - logger.info(f"=== Running page copy test on {system_type.upper()} system ===") + pool = setup_pool + logger.info(f"=== Running page copy test on system ===") (page0,) = pool.acquire_free_pages(1) page1 = pool.copy_page(page0) - assert page1 is not None, f"Failed to copy a page on {system_type} system" - assert ( - page0 != page1 - ), f"Copied page should be different from original on {system_type} system" - logger.info(f"Successfully copied page on {system_type.upper()} system") + assert page1 is not None, f"Failed to copy a page on system" + assert page0 != page1, f"Copied page should be different from original on system" + logger.info(f"Successfully copied page on system") @pytest.fixture(autouse=True) diff --git a/shortfin/tests/conftest.py b/shortfin/tests/conftest.py index 083698968..6a4adabfc 100644 --- a/shortfin/tests/conftest.py +++ b/shortfin/tests/conftest.py @@ -50,6 +50,17 @@ def pytest_runtest_setup(item): sf.SystemBuilder.default_system_type = system_type +# Dynamic Parameterization for lsys Fixture +def pytest_generate_tests(metafunc): + if "lsys" in metafunc.fixturenames: + system = metafunc.config.getoption("--system") + if system == "amdgpu": + params = ["cpu", "amdgpu"] + else: + params = [system] + metafunc.parametrize("lsys", params, indirect=True) + + # Keys that will be cleaned project wide prior to and after each test run. # Test code can freely modify these. 
CLEAN_ENV_KEYS = [ @@ -96,6 +107,28 @@ def kill(): kill() +@pytest.fixture +def generic_lsys(request): + system_type = request.param + if system_type == "cpu" or system_type == "hostcpu": + sc = sf.host.CPUSystemBuilder() + elif system_type == "amdgpu": + sc = sf.amdgpu.SystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() + + +@pytest.fixture +def generic_fiber(generic_lsys): + return generic_lsys.create_fiber() + + +@pytest.fixture +def generic_device(generic_fiber): + return generic_fiber.device(0) + + @pytest.fixture def cpu_lsys(): sc = sf.host.CPUSystemBuilder() From 5d79dbc354f88f80304d5cd60871e92f9e85e78f Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 22 Nov 2024 14:47:20 +0000 Subject: [PATCH 11/23] remove radix tree bcs not relevant to this pr --- .../llm/components/kvcache/radix_tree.py | 263 ------------------ 1 file changed, 263 deletions(-) delete mode 100644 shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py b/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py deleted file mode 100644 index aca2511b8..000000000 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/radix_tree.py +++ /dev/null @@ -1,263 +0,0 @@ -from __future__ import annotations -from typing import List, Dict, Optional, Tuple, TypeVar, Generic -from dataclasses import dataclass -from page_pool import PagePool - -T = TypeVar("T") # Generic type for page data - - -@dataclass -class RadixNode(Generic[T]): - """A node in the radix tree that tracks cached pages of data. - - Each node represents a sequence of tokens and maintains references to the pages - containing the cached data for those tokens. The node structure allows for - efficient prefix matching and sharing of cached data. - - Attributes: - children: Mapping of first token to child nodes - parent: Reference to parent node, None for root - key: Token sequence this node represents - pages: List of page data associated with this token sequence - last_access_timestamp: Unix timestamp of last access - ref_count: Number of active references to this node - - Example: - ```python - # Create a leaf node for token sequence [5, 2, 8] - node = RadixNode( - children={}, - parent=parent_node, - key=[5, 2, 8], - pages=[page1, page2], - ref_count=1 - ) - - # Access timestamp is automatically updated when node is accessed - assert node.last_access_timestamp > 0 - - # When done with node, decrement reference count - node.ref_count -= 1 - ``` - """ - - children: Dict[int, RadixNode[T]] - parent: Optional[RadixNode[T]] - key: List[int] - pages: List[T] - last_access_timestamp: int = 0 - ref_count: int = 0 - - -class RadixTree(Generic[T]): - """A radix tree implementation for caching token sequence data. - - The tree efficiently stores and retrieves cached data for token sequences, - enabling prefix sharing and fast lookups. It handles memory management through - reference counting and LRU eviction. 
- - Example: - ```python - # Initialize tree with a page pool and 16 tokens per page - tree = RadixTree(page_pool=my_pool, tokens_per_page=16) - - # Cache a sequence of tokens with their associated data - token_ids = [1, 5, 8, 2] - node = tree.cache_sequence(token_ids) - - # Later, find cached data for a prefix - pages, match_node = tree.match_prefix([1, 5, 8]) - assert len(pages) > 0 - - # When done with the cached data, release it - tree.release_pages(pages) - ``` - """ - - def __init__(self, page_pool: PagePool, tokens_per_page: int) -> None: - """Initialize the radix tree. - - Args: - page_pool: Pool that manages the underlying page allocations - tokens_per_page: Number of tokens that can be stored in each page - disable: If True, disables caching (useful for testing) - - Example: - ```python - tree = RadixTree( - page_pool=PagePool(...), - tokens_per_page=16, - disable=False - ) - ``` - """ - raise NotImplementedError() - - def reset(self) -> None: - """Reset the tree to initial empty state. - - Releases all cached pages and resets the tree to contain only a root node. - - Example: - ```python - tree = RadixTree(...) - - # Cache some sequences - tree.cache_sequence([1, 2, 3]) - tree.cache_sequence([4, 5, 6]) - - # Reset tree to clean state - tree.reset() - - # Tree is now empty except for root node - pages, _ = tree.match_prefix([1, 2, 3]) - assert len(pages) == 0 - ``` - """ - raise NotImplementedError() - - def match_prefix(self, token_ids: List[int]) -> Tuple[List[T], RadixNode[T]]: - """Find the longest matching prefix and return its cached pages. - - Args: - token_ids: Sequence of tokens to match against - - Returns: - Tuple containing: - - List of cached pages for the matching prefix - - The node containing the last matched token - - Example: - ```python - # Cache a sequence - tree.cache_sequence([1, 2, 3, 4, 5]) - - # Match a prefix - pages, node = tree.match_prefix([1, 2, 3]) - - # pages contains cached data for tokens [1, 2, 3] - assert len(pages) > 0 - - # node represents the position after [1, 2, 3] - assert node.key == [1, 2, 3] - - # Don't forget to release when done - tree.release_pages(pages) - ``` - """ - raise NotImplementedError() - - def cache_sequence( - self, token_ids: List[int], existing_pages: Optional[List[T]] = None - ) -> RadixNode[T]: - """Cache a token sequence, potentially extending existing cached pages. - - Args: - token_ids: Complete sequence of tokens to cache - existing_pages: Optional list of already cached pages to extend - - Returns: - Node containing the cached sequence - - Example: - ```python - # Cache initial sequence - node1 = tree.cache_sequence([1, 2, 3]) - - # Match prefix and extend with new tokens - pages, _ = tree.match_prefix([1, 2, 3]) - node2 = tree.cache_sequence([1, 2, 3, 4, 5], existing_pages=pages) - - # New node contains extended sequence - assert node2.key == [1, 2, 3, 4, 5] - - # Release pages when done - tree.release_pages(pages) - ``` - """ - raise NotImplementedError() - - def release_pages(self, pages: List[T]) -> None: - """Release references to cached pages. - - Decrements reference counts and potentially frees memory if counts reach zero. - - Args: - pages: List of pages to release - - Example: - ```python - # Get cached pages - pages, _ = tree.match_prefix([1, 2, 3]) - - # Use the pages... - - # Release when done - tree.release_pages(pages) - ``` - """ - raise NotImplementedError() - - def _get_match_len(self, key1: List[int], key2: List[int]) -> int: - """Return length of matching prefix between two keys. 
- - Args: - key1: First sequence of tokens - key2: Second sequence of tokens - - Returns: - Length of the matching prefix - - Example: - ```python - # Internal use for finding split points - length = tree._get_match_len([1, 2, 3, 4], [1, 2, 5, 6]) - assert length == 2 # Matches [1, 2] - ``` - """ - raise NotImplementedError() - - def _split_node(self, node: RadixNode[T], split_pos: int) -> RadixNode[T]: - """Split a node at the given position. - - Args: - node: Node to split - split_pos: Position in the node's key where split should occur - - Returns: - New intermediate node created by the split - - Example: - ```python - # Internal use during insertion - # If we have a node with key [1, 2, 3, 4] and need to - # insert [1, 2, 5, 6], we first split at position 2: - - old_node.key == [1, 2, 3, 4] - new_node = tree._split_node(old_node, 2) - - assert new_node.key == [1, 2] - assert old_node.key == [3, 4] - assert old_node.parent == new_node - ``` - """ - raise NotImplementedError() - - def _evict_pages(self, pages_needed: int) -> None: - """Evict pages using LRU strategy until enough pages are free. - - Args: - pages_needed: Number of pages that need to be freed - - Example: - ```python - # Internal use when cache is full - # If we need 5 pages and cache is full: - tree._evict_pages(5) - - # After eviction, at least 5 pages should be available - pages = page_pool.acquire_free_pages(5) - assert pages is not None - ``` - """ - raise NotImplementedError() From d6a6a2e5c1e3f971b9ca57761cb813bc2478f382 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 22 Nov 2024 14:56:02 +0000 Subject: [PATCH 12/23] clean up some stragglers --- .../shortfin_apps/llm/components/cache.py | 111 ------------------ .../llm/components/kvcache/page_pool.py | 2 +- .../shortfin_apps/llm/components/messages.py | 12 +- .../shortfin_apps/llm/components/service.py | 15 ++- .../tests/apps/llm/components/cache_test.py | 94 --------------- 5 files changed, 17 insertions(+), 217 deletions(-) delete mode 100644 shortfin/python/shortfin_apps/llm/components/cache.py delete mode 100644 shortfin/tests/apps/llm/components/cache_test.py diff --git a/shortfin/python/shortfin_apps/llm/components/cache.py b/shortfin/python/shortfin_apps/llm/components/cache.py deleted file mode 100644 index 12794498f..000000000 --- a/shortfin/python/shortfin_apps/llm/components/cache.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Sequence - -import logging -import math -import threading - -import shortfin as sf - -from .config_struct import ModelParams, human_size - -logger = logging.getLogger(__name__) - - -class AttnPageEntry: - __slots__ = [ - "cache", - "index", - "in_use", - ] - - def __init__(self, cache: "AttnPageCache", index: int): - self.cache = cache - self.index = index - self.in_use = False - - def __repr__(self): - return f"Block({self.index}, {'FREE' if not self.in_use else 'BUSY'})" - - -class AttnPageCache: - """Page table based attention cache. - - While internal to a model, the cache is organized with additional structure - per page, outside of the model, it is just a list of pages of a certain - element type and number of elements (all inner dims are flattened). - - One page table is allocated per device in a fiber. 
Currently, this is a - dense allocation with committed memory but in the future, we may just - allocate the address space and lazily populate it with committed memory. - - The cache is unique because usage of it can span fibers and concurrency - is implicitly managed at the block level (i.e. freshly acquired blocks - are assumed to be uninitialized and available immediately for use). - - It is initialized with a discrete list of fiberd devices from a fiber but - cache usage can be done from any fiber which includes those devices. - """ - - def __init__( - self, *, devices: Sequence[sf.ScopedDevice], model_params: ModelParams - ): - self._lock = threading.Lock() - self.devices = list(devices) - self.model_params = model_params - self.page_tables: list[sf.array.device_array] = [] - cache_params = model_params.paged_kv_cache - alloc_page_count = cache_params.device_block_count - - # Setup accounting structs. - self.attn_page_entries = [ - AttnPageEntry(self, i) for i in range(alloc_page_count) - ] - self.attn_page_free = list(self.attn_page_entries) - - # Initialize a page table on each device. - assert cache_params is not None, "Model does not have a paged kv cache" - page_table_shape = [ - alloc_page_count, - model_params.paged_kv_block_size_elements, - ] - for device in devices: - logging.info( - "Allocating page table (shape=%r, dtype=%r, size=%s) on %r", - page_table_shape, - model_params.attn_dtype, - human_size( - math.prod(page_table_shape) - * model_params.attn_dtype.dense_byte_count - ), - device, - ) - page_table = sf.array.device_array.for_device( - device, page_table_shape, model_params.attn_dtype - ) - self.page_tables.append(page_table) - - def acquire_free_pages(self, count: int) -> list[AttnPageEntry] | None: - with self._lock: - available = len(self.attn_page_free) - if count > available: - return None - return [self.attn_page_free.pop() for _ in range(count)] - - def release_pages(self, pages: list[AttnPageEntry]): - with self._lock: - self.attn_page_free.extend(pages) - - def __repr__(self): - # No need to lock for repr (list is internally synchronized). 
- free_pages = len(self.attn_page_free) - total_pages = len(self.attn_page_entries) - return ( - f"AttnPageCache({total_pages - free_pages}/{total_pages} pages in use: " - f"{100.0 * free_pages / total_pages}% free)" - ) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py index 3d106869c..8c9c3c0c3 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py @@ -151,7 +151,7 @@ def __repr__(self): free_pages = len(self.attn_page_free) total_pages = len(self.attn_page_entries) return ( - f"AttnPageCache({total_pages - free_pages}/{total_pages} pages in use: " + f"PagePool({total_pages - free_pages}/{total_pages} pages in use: " f"{100.0 * free_pages / total_pages}% free)" ) diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index fdcbeefc1..38549519c 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -9,7 +9,7 @@ import shortfin as sf import shortfin.array as sfnp -from .cache import AttnPageCache, AttnPageEntry +from .cache import BasePagedAttentionCache, AttnPageEntry class InferencePhase(Enum): @@ -41,7 +41,7 @@ def __init__(self, phase: InferencePhase, input_token_ids: list[int]): self.result_logits: sfnp.device_array | None = None # Cache pages that have been locked for this request. - self._cache: AttnPageCache | None = None + self._cache: BasePagedAttentionCache | None = None self.locked_pages: list[AttnPageEntry] | None = None def reset(self, phase: InferencePhase): @@ -66,16 +66,18 @@ def free_cache_pages(self): pages = self.locked_pages self._cache = None self.locked_pages = None - cache.release_pages(pages) + cache.release_pages(self.input_token_ids, pages) def lock_initial_cache_pages( - self, cache: AttnPageCache, pages: list[AttnPageEntry] + self, cache: BasePagedAttentionCache, pages: list[AttnPageEntry] ): assert not self._cache self._cache = cache self.locked_pages = pages - def lock_new_cache_pages(self, cache: AttnPageCache, pages: list[AttnPageEntry]): + def lock_new_cache_pages( + self, cache: BasePagedAttentionCache, pages: list[AttnPageEntry] + ): assert self._cache is cache self.locked_pages.extend(pages) diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index efbcfc747..5ab2aebd6 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -11,7 +11,7 @@ import shortfin as sf import shortfin.array as sfnp -from .cache import AttnPageCache +from .cache import BasePagedAttentionCache from .config_struct import ModelParams from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage @@ -54,7 +54,7 @@ def __init__( # Scope dependent objects. self.batcher = BatcherProcess(self) - self.page_cache = AttnPageCache( + self.page_cache = BasePagedAttentionCache( devices=self.main_fiber.devices_dict.values(), model_params=model_params ) @@ -200,7 +200,7 @@ def board_flights(self): self.pending_prefills.clear() logger.debug("Post boarding cache state: %r", cache) - def board_prefills(self, cache: AttnPageCache): + def board_prefills(self, cache: BasePagedAttentionCache): # Fill prefill flights. 
pending_prefills = self.pending_prefills if len(pending_prefills) == 0: @@ -209,7 +209,7 @@ def board_prefills(self, cache: AttnPageCache): self.service, InferencePhase.PREFILL, self.page_seq_stride, - cache.page_tables, + cache.page_pool.page_tables, ) for prefill_request in pending_prefills: assert prefill_request.phase == InferencePhase.PREFILL @@ -240,13 +240,16 @@ def board_prefills(self, cache: AttnPageCache): # And takeoff. exec_process.launch() - def board_decodes(self, cache: AttnPageCache): + def board_decodes(self, cache: BasePagedAttentionCache): # Fill decode flights. pending_decodes = self.pending_decodes if len(pending_decodes) == 0: return exec_process = InferenceExecutorProcess( - self.service, InferencePhase.DECODE, self.page_seq_stride, cache.page_tables + self.service, + InferencePhase.DECODE, + self.page_seq_stride, + cache.page_pool.page_tables, ) for decode_request in pending_decodes: assert decode_request.phase == InferencePhase.DECODE diff --git a/shortfin/tests/apps/llm/components/cache_test.py b/shortfin/tests/apps/llm/components/cache_test.py deleted file mode 100644 index 169d082b1..000000000 --- a/shortfin/tests/apps/llm/components/cache_test.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -""" -Tests for llm kvcache component. -""" - -import pytest -import time -import tempfile -import shortfin as sf -from _shortfin import lib as sfl -from shortfin_apps.llm.components import cache -from shortfin_apps.llm.components import config_struct -import json -from pathlib import Path - - -@pytest.fixture -def lsys(): - sc = sfl.local.host.CPUSystemBuilder() - ls = sc.create_system() - yield ls - ls.shutdown() - - -@pytest.fixture -def fiber(lsys): - # TODO: Should adopt the main thread. 
- worker = lsys.create_worker("main") - return lsys.create_fiber(worker) - - -@pytest.fixture -def device(fiber): - return fiber.device(0) - - -@pytest.fixture -def model_params(): - model_params = { - "module_name": "module", - "module_abi_version": 1, - "max_seq_len": 2048, - "attn_head_count": 32, - "attn_head_dim": 100, - "prefill_batch_sizes": [4], - "decode_batch_sizes": [4], - "transformer_block_count": 26, - "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, - } - - # Create a temporary file to store the JSON - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as tmp_file: - json.dump(model_params, tmp_file, indent=4) - tmp_path = Path(tmp_file.name) - - try: - # Load the JSON using config_struct - model_params = config_struct.ModelParams.load_json(tmp_path) - yield model_params - finally: - tmp_path.unlink - - -@pytest.fixture -def cache_fixture(fiber, model_params) -> cache.AttnPageCache: - # Create and return the cache object - return cache.AttnPageCache( - devices=fiber.devices_dict.values(), model_params=model_params - ) - - -@pytest.mark.parametrize("n_allocated", [1, 16, 255]) -def test_alloc( - cache_fixture: cache.AttnPageCache, - n_allocated, - model_params: config_struct.ModelParams, -): - alloc_page_count = cache_fixture.page_tables[0].shape[0] - - assert alloc_page_count == model_params.paged_kv_cache.device_block_count - - pages = cache_fixture.acquire_free_pages(n_allocated) - last_page = alloc_page_count - 1 - expected_indices = range(last_page, last_page - n_allocated, -1) - for p, expected_ix in zip(pages, expected_indices): - assert p.index == expected_ix - assert p.index > 0 From 81d01317ccfcbbbeba42159e7dbd37188beaaf82 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 22 Nov 2024 15:07:08 +0000 Subject: [PATCH 13/23] fix pytest scope --- shortfin/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shortfin/tests/conftest.py b/shortfin/tests/conftest.py index 6a4adabfc..044147bae 100644 --- a/shortfin/tests/conftest.py +++ b/shortfin/tests/conftest.py @@ -107,7 +107,7 @@ def kill(): kill() -@pytest.fixture +@pytest.fixture(scope="session") def generic_lsys(request): system_type = request.param if system_type == "cpu" or system_type == "hostcpu": From 49b10bb5ea1e1e807c4cdbd28634b93abcf7c03c Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 22 Nov 2024 15:20:00 +0000 Subject: [PATCH 14/23] fix tests --- .../apps/llm/components/kvcache/page_pool_test.py | 14 -------------- shortfin/tests/conftest.py | 4 ++-- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py index f9e0f6e4c..a1ec00c07 100644 --- a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py +++ b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py @@ -9,20 +9,6 @@ logger = logging.getLogger(__name__) -@pytest.fixture( - params=[("cpu", sf.host.CPUSystemBuilder), ("gpu", sf.amdgpu.SystemBuilder)] -) -def setup_system(request): - system_type, builder_class = request.param - logger.info(f"=== Setting up {system_type.upper()} system ===") - sc = builder_class() - lsys = sc.create_system() - fiber = lsys.create_fiber() - devices = fiber.devices_dict.values() - yield system_type, lsys, devices - lsys.shutdown() - - @pytest.fixture def setup_pool(generic_device): pool = PagePool( diff --git a/shortfin/tests/conftest.py b/shortfin/tests/conftest.py index 044147bae..b16d5a3c9 100644 
--- a/shortfin/tests/conftest.py +++ b/shortfin/tests/conftest.py @@ -52,13 +52,13 @@ def pytest_runtest_setup(item): # Dynamic Parameterization for lsys Fixture def pytest_generate_tests(metafunc): - if "lsys" in metafunc.fixturenames: + if "generic_lsys" in metafunc.fixturenames: system = metafunc.config.getoption("--system") if system == "amdgpu": params = ["cpu", "amdgpu"] else: params = [system] - metafunc.parametrize("lsys", params, indirect=True) + metafunc.parametrize("generic_lsys", params, indirect=True) # Keys that will be cleaned project wide prior to and after each test run. From 5c20685e9cffb4edb3d4cd3b7b949ee284579a6e Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 22 Nov 2024 15:33:49 +0000 Subject: [PATCH 15/23] replace some more references --- .../python/shortfin_apps/llm/components/messages.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index 38549519c..fb20be540 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -9,7 +9,11 @@ import shortfin as sf import shortfin.array as sfnp -from .cache import BasePagedAttentionCache, AttnPageEntry +from .kvcache.base_attention_cache import BasePagedAttentionCache +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .kvcache.page_pool import PageInfo class InferencePhase(Enum): @@ -42,7 +46,7 @@ def __init__(self, phase: InferencePhase, input_token_ids: list[int]): # Cache pages that have been locked for this request. self._cache: BasePagedAttentionCache | None = None - self.locked_pages: list[AttnPageEntry] | None = None + self.locked_pages: list[PageInfo] | None = None def reset(self, phase: InferencePhase): """Resets all per request state in preparation for an subsequent execution.""" @@ -69,14 +73,14 @@ def free_cache_pages(self): cache.release_pages(self.input_token_ids, pages) def lock_initial_cache_pages( - self, cache: BasePagedAttentionCache, pages: list[AttnPageEntry] + self, cache: BasePagedAttentionCache, pages: list[PageInfo] ): assert not self._cache self._cache = cache self.locked_pages = pages def lock_new_cache_pages( - self, cache: BasePagedAttentionCache, pages: list[AttnPageEntry] + self, cache: BasePagedAttentionCache, pages: list[PageInfo] ): assert self._cache is cache self.locked_pages.extend(pages) From 6233957e1229a7e427207e03bdb2d059982bf072 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 22 Nov 2024 16:32:18 +0000 Subject: [PATCH 16/23] various changes for compatibility with new PagePool --- .../llm/components/kvcache/base_attention_cache.py | 4 ++-- .../llm/components/kvcache/page_pool.py | 8 ++++---- .../shortfin_apps/llm/components/messages.py | 5 +---- .../python/shortfin_apps/llm/components/service.py | 14 ++++++++++++-- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index e02d55851..7b9f38145 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -9,7 +9,7 @@ """ from typing import List -from attention_paging import PageInfo +from .page_pool import PageInfo import math @@ -53,7 +53,7 @@ def acquire_pages_for_tokens( No token at idx < n_cached_token should be written 
to. TODO: consider enforcing this. """ - pages_needed = math.ceil(len(tokens + extra_token_slots) / self.tokens_per_page) + pages_needed = math.ceil(len(tokens) + extra_token_slots / self.tokens_per_page) pages = self.page_pool.acquire_free_pages(pages_needed) n_cached_tokens = 0 diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py index 8c9c3c0c3..1686370c0 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py @@ -20,7 +20,7 @@ class PageInfo: Page index with some metadata about its contents. """ - page_index: int + index: int pool: PagePool token_offset: int # Offset within the page token_count: int # Number of tokens stored in this page @@ -78,7 +78,7 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig # Setup accounting structs. self.attn_page_entries = [ PageInfo( - page_index=i, + index=i, pool=self, token_offset=0, token_count=0, @@ -136,8 +136,8 @@ def copy_page(self, src_page: PageInfo) -> PageInfo: # Copy the data on each device for page_table in self.page_tables: # View of source and destination pages - src_view = page_table.view(src_page.page_index) - dst_view = page_table.view(dst_page.page_index) + src_view = page_table.view(src_page.index) + dst_view = page_table.view(dst_page.index) # Copy the data dst_view.copy_from(src_view) diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index fb20be540..c3e6fe34b 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -10,10 +10,7 @@ import shortfin.array as sfnp from .kvcache.base_attention_cache import BasePagedAttentionCache -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from .kvcache.page_pool import PageInfo +from .kvcache.page_pool import PageInfo class InferencePhase(Enum): diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 5ab2aebd6..8d3cc1424 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -11,7 +11,8 @@ import shortfin as sf import shortfin.array as sfnp -from .cache import BasePagedAttentionCache +from .kvcache.base_attention_cache import BasePagedAttentionCache +from .kvcache.page_pool import PagePoolConfig, PagePool from .config_struct import ModelParams from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage @@ -54,8 +55,17 @@ def __init__( # Scope dependent objects. 
self.batcher = BatcherProcess(self) + page_pool_config = PagePoolConfig( + dtype=model_params.attn_dtype, + alloc_page_count=model_params.paged_kv_cache.device_block_count, + paged_kv_block_size_elements=model_params.paged_kv_block_size_elements, + ) + page_pool = PagePool( + devices=self.main_fiber.devices_dict.values(), config=page_pool_config + ) self.page_cache = BasePagedAttentionCache( - devices=self.main_fiber.devices_dict.values(), model_params=model_params + page_pool=page_pool, + tokens_per_page=model_params.paged_kv_cache.block_seq_stride, ) self.program_isolation = PROG_ISOLATIONS[program_isolation] From f8856c307ee4532e594c2dd1c11653af554bc92d Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Sun, 24 Nov 2024 05:30:18 +0000 Subject: [PATCH 17/23] add tokenizers as nogil test dependency --- shortfin/requirements-tests-nogil.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/shortfin/requirements-tests-nogil.txt b/shortfin/requirements-tests-nogil.txt index 1769467ab..33e27d57a 100644 --- a/shortfin/requirements-tests-nogil.txt +++ b/shortfin/requirements-tests-nogil.txt @@ -1,3 +1,4 @@ pytest requests uvicorn +tokenizers From 5eb2ef303077c1157e028d5f4130a5cbf83add56 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Sun, 24 Nov 2024 05:38:45 +0000 Subject: [PATCH 18/23] Revert "add tokenizers as nogil test dependency" This reverts commit d1cdf649b157d28ef7ebfe053b6c65d2c592fbd7. --- shortfin/requirements-tests-nogil.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/shortfin/requirements-tests-nogil.txt b/shortfin/requirements-tests-nogil.txt index 33e27d57a..1769467ab 100644 --- a/shortfin/requirements-tests-nogil.txt +++ b/shortfin/requirements-tests-nogil.txt @@ -1,4 +1,3 @@ pytest requests uvicorn -tokenizers From 33240628a12277f5d93c0782bfeda8b18b65bc4d Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Sun, 24 Nov 2024 06:12:04 +0000 Subject: [PATCH 19/23] skip when dependencies not found in a way compatible with shortfin --- shortfin/python/shortfin_apps/llm/_deps.py | 23 ++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/_deps.py b/shortfin/python/shortfin_apps/llm/_deps.py index 7123d011e..5b18b6558 100644 --- a/shortfin/python/shortfin_apps/llm/_deps.py +++ b/shortfin/python/shortfin_apps/llm/_deps.py @@ -5,13 +5,20 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from shortfin.support.deps import ShortfinDepNotFoundError +import sys -try: - import tokenizers -except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "tokenizers") from e +shortfin_llm_deps = [ + "tokenizers", + "dataclasses_json", +] -try: - import dataclasses_json -except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e +for dep in shortfin_llm_deps: + try: + __import__(dep) + except ModuleNotFoundError as e: + if "pytest" in sys.modules: + import pytest + + pytest.skip(f"Shortfin LLM dependency not available: {dep}") + else: + raise ShortfinDepNotFoundError(__name__, dep) from e From 022d845dbc45490c57126645f022ae8e699a830a Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Sun, 24 Nov 2024 06:21:44 +0000 Subject: [PATCH 20/23] add allow_module_level and apply to sd too --- .../ci_linux_x64_nogil-libshortfin.yml | 2 +- shortfin/python/shortfin_apps/llm/_deps.py | 8 ++++--- shortfin/python/shortfin_apps/sd/_deps.py | 24 +++++++++---------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml 
b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml index 0e0e1db2a..5a4065702 100644 --- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml @@ -98,6 +98,6 @@ jobs: - name: Run shortfin Python tests (full) working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py --ignore=tests/apps/sd + pytest -s --ignore=tests/examples/fastapi_test.py # TODO: Enable further tests and switch to # pytest -s diff --git a/shortfin/python/shortfin_apps/llm/_deps.py b/shortfin/python/shortfin_apps/llm/_deps.py index 5b18b6558..ba7969564 100644 --- a/shortfin/python/shortfin_apps/llm/_deps.py +++ b/shortfin/python/shortfin_apps/llm/_deps.py @@ -7,18 +7,20 @@ from shortfin.support.deps import ShortfinDepNotFoundError import sys -shortfin_llm_deps = [ +deps = [ "tokenizers", "dataclasses_json", ] -for dep in shortfin_llm_deps: +for dep in deps: try: __import__(dep) except ModuleNotFoundError as e: if "pytest" in sys.modules: import pytest - pytest.skip(f"Shortfin LLM dependency not available: {dep}") + pytest.skip( + f"Shortfin LLM dependency not available: {dep}", allow_module_level=True + ) else: raise ShortfinDepNotFoundError(__name__, dep) from e diff --git a/shortfin/python/shortfin_apps/sd/_deps.py b/shortfin/python/shortfin_apps/sd/_deps.py index 92bd089ec..45544de2d 100644 --- a/shortfin/python/shortfin_apps/sd/_deps.py +++ b/shortfin/python/shortfin_apps/sd/_deps.py @@ -6,17 +6,17 @@ from shortfin.support.deps import ShortfinDepNotFoundError -try: - import transformers -except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "transformers") from e +shortfin_llm_deps = ["tokenizers", "dataclasses_json", "transformers"] -try: - import tokenizers -except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "tokenizers") from e +for dep in deps: + try: + __import__(dep) + except ModuleNotFoundError as e: + if "pytest" in sys.modules: + import pytest -try: - import dataclasses_json -except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e + pytest.skip( + f"Shortfin LLM dependency not available: {dep}", allow_module_level=True + ) + else: + raise ShortfinDepNotFoundError(__name__, dep) from e From 8298ddabb045f944964f95661d1e955c020dd555 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Sun, 24 Nov 2024 06:26:45 +0000 Subject: [PATCH 21/23] actually let's not touch sd for now --- .../ci_linux_x64_nogil-libshortfin.yml | 2 +- shortfin/python/shortfin_apps/sd/_deps.py | 24 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml index 5a4065702..550366e1b 100644 --- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml @@ -98,6 +98,6 @@ jobs: - name: Run shortfin Python tests (full) working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - pytest -s --ignore=tests/examples/fastapi_test.py + pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/sd # TODO: Enable further tests and switch to # pytest -s diff --git a/shortfin/python/shortfin_apps/sd/_deps.py b/shortfin/python/shortfin_apps/sd/_deps.py index 45544de2d..92bd089ec 100644 --- a/shortfin/python/shortfin_apps/sd/_deps.py +++ b/shortfin/python/shortfin_apps/sd/_deps.py @@ -6,17 +6,17 @@ from shortfin.support.deps 
import ShortfinDepNotFoundError -shortfin_llm_deps = ["tokenizers", "dataclasses_json", "transformers"] +try: + import transformers +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "transformers") from e -for dep in deps: - try: - __import__(dep) - except ModuleNotFoundError as e: - if "pytest" in sys.modules: - import pytest +try: + import tokenizers +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "tokenizers") from e - pytest.skip( - f"Shortfin LLM dependency not available: {dep}", allow_module_level=True - ) - else: - raise ShortfinDepNotFoundError(__name__, dep) from e +try: + import dataclasses_json +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e From a8d59cc0d34a4fa180dd41fa11d4d409e5c1d787 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Sun, 24 Nov 2024 06:39:40 +0000 Subject: [PATCH 22/23] better error message upon _deps skip --- shortfin/python/shortfin_apps/llm/_deps.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/shortfin/python/shortfin_apps/llm/_deps.py b/shortfin/python/shortfin_apps/llm/_deps.py index ba7969564..fb8ca8176 100644 --- a/shortfin/python/shortfin_apps/llm/_deps.py +++ b/shortfin/python/shortfin_apps/llm/_deps.py @@ -20,7 +20,8 @@ import pytest pytest.skip( - f"Shortfin LLM dependency not available: {dep}", allow_module_level=True + f"A test imports shortfin_apps.llm; skipping due to unavailable Shortfin LLM dependency: {dep}", + allow_module_level=True, ) else: raise ShortfinDepNotFoundError(__name__, dep) from e From 1c08e34fbacf89f64d524ebca197385b54c68c81 Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 25 Nov 2024 16:07:57 -0800 Subject: [PATCH 23/23] missed problem with math --- .../llm/components/kvcache/base_attention_cache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index 7b9f38145..0007000bc 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -53,7 +53,8 @@ def acquire_pages_for_tokens( No token at idx < n_cached_token should be written to. TODO: consider enforcing this. """ - pages_needed = math.ceil(len(tokens) + extra_token_slots / self.tokens_per_page) + token_count = len(tokens) + pages_needed = math.ceil(token_count / self.tokens_per_page) pages = self.page_pool.acquire_free_pages(pages_needed) n_cached_tokens = 0