diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index e02d55851..7b9f38145 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -9,7 +9,7 @@ """ from typing import List -from attention_paging import PageInfo +from .page_pool import PageInfo import math @@ -53,7 +53,7 @@ def acquire_pages_for_tokens( No token at idx < n_cached_token should be written to. TODO: consider enforcing this. """ - pages_needed = math.ceil(len(tokens + extra_token_slots) / self.tokens_per_page) + pages_needed = math.ceil(len(tokens) + extra_token_slots / self.tokens_per_page) pages = self.page_pool.acquire_free_pages(pages_needed) n_cached_tokens = 0 diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py index 8c9c3c0c3..1686370c0 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py @@ -20,7 +20,7 @@ class PageInfo: Page index with some metadata about its contents. """ - page_index: int + index: int pool: PagePool token_offset: int # Offset within the page token_count: int # Number of tokens stored in this page @@ -78,7 +78,7 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig # Setup accounting structs. self.attn_page_entries = [ PageInfo( - page_index=i, + index=i, pool=self, token_offset=0, token_count=0, @@ -136,8 +136,8 @@ def copy_page(self, src_page: PageInfo) -> PageInfo: # Copy the data on each device for page_table in self.page_tables: # View of source and destination pages - src_view = page_table.view(src_page.page_index) - dst_view = page_table.view(dst_page.page_index) + src_view = page_table.view(src_page.index) + dst_view = page_table.view(dst_page.index) # Copy the data dst_view.copy_from(src_view) diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index fb20be540..c3e6fe34b 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -10,10 +10,7 @@ import shortfin.array as sfnp from .kvcache.base_attention_cache import BasePagedAttentionCache -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from .kvcache.page_pool import PageInfo +from .kvcache.page_pool import PageInfo class InferencePhase(Enum): diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 5ab2aebd6..8d3cc1424 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -11,7 +11,8 @@ import shortfin as sf import shortfin.array as sfnp -from .cache import BasePagedAttentionCache +from .kvcache.base_attention_cache import BasePagedAttentionCache +from .kvcache.page_pool import PagePoolConfig, PagePool from .config_struct import ModelParams from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage @@ -54,8 +55,17 @@ def __init__( # Scope dependent objects. self.batcher = BatcherProcess(self) + page_pool_config = PagePoolConfig( + dtype=model_params.attn_dtype, + alloc_page_count=model_params.paged_kv_cache.device_block_count, + paged_kv_block_size_elements=model_params.paged_kv_block_size_elements, + ) + page_pool = PagePool( + devices=self.main_fiber.devices_dict.values(), config=page_pool_config + ) self.page_cache = BasePagedAttentionCache( - devices=self.main_fiber.devices_dict.values(), model_params=model_params + page_pool=page_pool, + tokens_per_page=model_params.paged_kv_cache.block_seq_stride, ) self.program_isolation = PROG_ISOLATIONS[program_isolation]