Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace AttnPagedCache with BasePagedAttentionCache #565

Merged
merged 23 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci_linux_x64_nogil-libshortfin.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@ jobs:
- name: Run shortfin Python tests (full)
working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py --ignore=tests/apps/sd
pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/sd
# TODO: Enable further tests and switch to
# pytest -s
26 changes: 18 additions & 8 deletions shortfin/python/shortfin_apps/llm/_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,23 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from shortfin.support.deps import ShortfinDepNotFoundError
import sys

try:
import tokenizers
except ModuleNotFoundError as e:
raise ShortfinDepNotFoundError(__name__, "tokenizers") from e
deps = [
"tokenizers",
"dataclasses_json",
]

try:
import dataclasses_json
except ModuleNotFoundError as e:
raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e
for dep in deps:
try:
__import__(dep)
except ModuleNotFoundError as e:
if "pytest" in sys.modules:
import pytest

pytest.skip(
f"A test imports shortfin_apps.llm; skipping due to unavailable Shortfin LLM dependency: {dep}",
allow_module_level=True,
)
else:
raise ShortfinDepNotFoundError(__name__, dep) from e
111 changes: 0 additions & 111 deletions shortfin/python/shortfin_apps/llm/components/cache.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Base class for kv caches.
"""

from typing import List
from .page_pool import PageInfo
import math


class BasePagedAttentionCache:
"""
Manages lifecycle of pages (using PageInfo as handles).


Page States:
Caching - Page can be read by multiple threads
- Also maintains a reference count
Writing - Page is being modified by a single owner thread

Transitions:
Caching -> Writing: When acquiring an unreferenced LRU leaf page for writing
Writing -> Caching: When writing is complete and page is released

Thread Safety:
- Multiple readers allowed in ReadableCaching state
- Single writer exclusive access in Writing state
- Reference counting prevents eviction of in-use pages
"""

def __init__(self, page_pool, tokens_per_page):
self.page_pool = page_pool
self.tokens_per_page = tokens_per_page

def acquire_pages_for_tokens(
self, tokens: List[int], extra_token_slots: int = 1
) -> tuple[list[PageInfo], int]:
"""
Given a list of tokens, return a list of pages and a start position to continue generation from.

Parameters:
- tokens: all the known tokens for this generation request
- extra_token_slots: number of kvcache slots needed in addition to the ones needed to hold the given tokens.

In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable.

The pages are returned in order.

No token at idx < n_cached_token should be written to. TODO: consider enforcing this.
"""
token_count = len(tokens)
pages_needed = math.ceil(token_count / self.tokens_per_page)
pages = self.page_pool.acquire_free_pages(pages_needed)

n_cached_tokens = 0

return pages, n_cached_tokens

def publish_pages(self, tokens, pages) -> None:
"""
Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests.

Associates the tokens with the pages, and mark them as done writing.

It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)].
"""

pass # the base implementation doesn't cache unfinished requests.

def release_pages(self, tokens, pages):
stbaione marked this conversation as resolved.
Show resolved Hide resolved
"""
Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction.
"""
# in the base implementation, the pages can be owned by 1 request max, so they can be instantly release
self.page_pool.release_pages(pages)
159 changes: 159 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from __future__ import annotations
from typing import List, Tuple, Optional, Sequence
import threading
import logging
import shortfin as sf
import shortfin.array as sfnp
from dataclasses import dataclass

from ..config_struct import human_size
import math

import time

logger = logging.getLogger(__name__)


@dataclass
class PageInfo:
"""
Page index with some metadata about its contents.
"""

index: int
pool: PagePool
token_offset: int # Offset within the page
token_count: int # Number of tokens stored in this page
writing: bool = False
read_ref_count: int = 0 # Number of threads that still need to read this page. When this reaches 0, page is eligible for release


@dataclass
class PagePoolConfig:
"""
Hyperparameters for the page pool.
"""

dtype: sf.dtype
alloc_page_count: int

paged_kv_block_size_elements: int # size of a single page as # of elements
# (e.g. one configuration for llama3.1 8b hax 32x2x16x8x128=1048576 elements where:
# 32: number of transformer blocks
# 2: one for k + one for v
# 16: tokens per page
# 8: head count (32 heads, but every 4 heads share the same kv buffer)
# 128: hidden dimension


class PagePool:
"""Page table based attention cache.

While internal to a model, the cache is organized with additional structure
per page, outside of the model, it is just a list of pages of a certain
element type and number of elements (all inner dims are flattened).

One page table is allocated per device in a fiber. Currently, this is a
dense allocation with committed memory but in the future, we may just
allocate the address space and lazily populate it with committed memory.

The cache is unique because usage of it can span fibers and concurrency
is implicitly managed at the block level (i.e. freshly acquired blocks
are assumed to be uninitialized and available immediately for use).

It is initialized with a discrete list of fiberd devices from a fiber but
cache usage can be done from any fiber which includes those devices.

In addition to supporting paged attention standalone, this also serves
as the array / buffer allocation layer for radix attention described in
`radix_tree.py`.
"""

def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig):
self._lock = threading.Lock()
self.devices = list(devices)
self.config = config
self.page_tables: list[sf.array.device_array] = []

# Setup accounting structs.
self.attn_page_entries = [
PageInfo(
index=i,
pool=self,
token_offset=0,
token_count=0,
)
for i in range(self.config.alloc_page_count)
]

self.attn_page_free = list(self.attn_page_entries)

# Initialize a page table on each device.
page_table_shape = [
self.config.alloc_page_count,
self.config.paged_kv_block_size_elements,
]
for device in devices:
logging.info(
"Allocating page table (shape=%r, dtype=%r, size=%s) on %r",
page_table_shape,
self.config.dtype,
human_size(config.dtype.compute_dense_nd_size(page_table_shape)),
device,
)
page_table = sf.array.device_array.for_device(
device, page_table_shape, self.config.dtype
)
self.page_tables.append(page_table)

def acquire_free_pages(self, count: int) -> list[PageInfo] | None:
with self._lock:
available = len(self.attn_page_free)
if count > available:
return None
return [self.attn_page_free.pop() for _ in range(count)]
stbaione marked this conversation as resolved.
Show resolved Hide resolved

def release_pages(self, pages: list[PageInfo]):
stbaione marked this conversation as resolved.
Show resolved Hide resolved
with self._lock:
self.attn_page_free.extend(pages)

def copy_page(self, src_page: PageInfo) -> PageInfo:
"""
Copy a page's contents to a new page.

Args:
src_page: Source page to copy from
token_count: Optional number of tokens to copy. If None, copies all tokens.

Returns:
New PageInfo containing the copied data
"""
# Allocate new page
(dst_page,) = self.acquire_free_pages(1)

# fill src page with data

# Copy the data on each device
for page_table in self.page_tables:
# View of source and destination pages
src_view = page_table.view(src_page.index)
dst_view = page_table.view(dst_page.index)
# Copy the data
dst_view.copy_from(src_view)

# Setup destination page metadata
dst_page.token_offset = 0 # Always start at beginning of new page

return dst_page

def __repr__(self):
# No need to lock for repr (list is internally synchronized).
free_pages = len(self.attn_page_free)
total_pages = len(self.attn_page_entries)
return (
f"PagePool({total_pages - free_pages}/{total_pages} pages in use: "
f"{100.0 * free_pages / total_pages}% free)"
)


############################## begin radix attention
Loading
Loading