Skip to content

Commit

Permalink
various changes for compatibility with new PagePool
Browse files Browse the repository at this point in the history
  • Loading branch information
renxida committed Nov 22, 2024
1 parent 861507c commit 0eaae82
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

from typing import List
from attention_paging import PageInfo
from .page_pool import PageInfo
import math


Expand Down Expand Up @@ -53,7 +53,7 @@ def acquire_pages_for_tokens(
No token at idx < n_cached_tokens should be written to. TODO: consider enforcing this.
"""
pages_needed = math.ceil(len(tokens + extra_token_slots) / self.tokens_per_page)
pages_needed = math.ceil(len(tokens) + extra_token_slots / self.tokens_per_page)
pages = self.page_pool.acquire_free_pages(pages_needed)

n_cached_tokens = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class PageInfo:
Page index with some metadata about its contents.
"""

page_index: int
index: int
pool: PagePool
token_offset: int # Offset within the page
token_count: int # Number of tokens stored in this page
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig
# Setup accounting structs.
self.attn_page_entries = [
PageInfo(
page_index=i,
index=i,
pool=self,
token_offset=0,
token_count=0,
Expand Down Expand Up @@ -136,8 +136,8 @@ def copy_page(self, src_page: PageInfo) -> PageInfo:
# Copy the data on each device
for page_table in self.page_tables:
# View of source and destination pages
src_view = page_table.view(src_page.page_index)
dst_view = page_table.view(dst_page.page_index)
src_view = page_table.view(src_page.index)
dst_view = page_table.view(dst_page.index)
# Copy the data
dst_view.copy_from(src_view)

Expand Down
5 changes: 1 addition & 4 deletions shortfin/python/shortfin_apps/llm/components/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
import shortfin.array as sfnp

from .kvcache.base_attention_cache import BasePagedAttentionCache
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .kvcache.page_pool import PageInfo
from .kvcache.page_pool import PageInfo


class InferencePhase(Enum):
Expand Down
14 changes: 12 additions & 2 deletions shortfin/python/shortfin_apps/llm/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import shortfin as sf
import shortfin.array as sfnp

from .cache import BasePagedAttentionCache
from .kvcache.base_attention_cache import BasePagedAttentionCache
from .kvcache.page_pool import PagePoolConfig, PagePool
from .config_struct import ModelParams
from .manager import SystemManager
from .messages import InferenceExecRequest, InferencePhase, StrobeMessage
Expand Down Expand Up @@ -54,8 +55,17 @@ def __init__(

# Scope dependent objects.
self.batcher = BatcherProcess(self)
page_pool_config = PagePoolConfig(
dtype=model_params.attn_dtype,
alloc_page_count=model_params.paged_kv_cache.device_block_count,
paged_kv_block_size_elements=model_params.paged_kv_block_size_elements,
)
page_pool = PagePool(
devices=self.main_fiber.devices_dict.values(), config=page_pool_config
)
self.page_cache = BasePagedAttentionCache(
devices=self.main_fiber.devices_dict.values(), model_params=model_params
page_pool=page_pool,
tokens_per_page=model_params.paged_kv_cache.block_seq_stride,
)

self.program_isolation = PROG_ISOLATIONS[program_isolation]
Expand Down

0 comments on commit 0eaae82

Please sign in to comment.