[llama] Update kv cache to have read/write functions #280

Merged 7 commits on Oct 29, 2024
132 changes: 113 additions & 19 deletions sharktank/sharktank/layers/kv_cache.py
@@ -92,6 +92,7 @@ def __init__(
attn_head_count: int,
attn_head_dim: int,
seq_length: int,
shard_count: int = 1,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
@@ -100,6 +101,7 @@ def __init__(
self.attn_head_count = attn_head_count
self.attn_head_dim = attn_head_dim
self.seq_length = seq_length
self.shard_count = shard_count
self.device = device
self.dtype = dtype

@@ -113,15 +115,109 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]:

Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim]
"""
return [
allocations = [
torch.empty(
[bs, self.seq_length, self.attn_head_count, self.attn_head_dim],
[
bs,
self.seq_length,
self.attn_head_count,
self.attn_head_dim,
],
dtype=self.dtype,
device=self.device,
)
for _ in range(2 * self.transformer_block_count)
]

if self.shard_count == 1:
return allocations

return [
ops.reshard_split(allocation, dim=2, count=self.shard_count)
for allocation in allocations
]
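
To make the head-dimension sharding concrete, here is a minimal plain-PyTorch sketch of the split that ops.reshard_split(dim=2, count=shard_count) is expected to perform on each allocated K/V buffer; torch.tensor_split is used as a stand-in and the SplitPrimitiveTensor wrapper is not reproduced (shapes and counts are illustrative only).

import torch

# Illustrative values; the real buffers come from allocate() above.
bs, sl, heads, dim, shard_count = 2, 16, 8, 32, 2
unsharded = torch.empty(bs, sl, heads, dim)

# Stand-in for reshard_split(dim=2, ...): split the attention-head axis.
shards = torch.tensor_split(unsharded, shard_count, dim=2)
assert all(s.shape == (bs, sl, heads // shard_count, dim) for s in shards)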

def read(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
seq_len: int,
page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
):
"""Reads cache partitions from the page table for the given page_ids.

Args:
state: State struct as returned from allocate().
read_into_partitions: List of cache partitions to read into in-place.
transformer_block_index: The index of the transformer block accessing
the cache.
page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids
to access.

Returns a tuple of cache partitions (i.e. k and v caches for the transformer
block), linearized. Note that this reference approach to reading by
materializing linearly may not be terribly efficient unless the compiler
can fuse the gather.
"""
read_count = len(read_into_partitions)
reads = []
for i in range(read_count):
reads.append(
state[transformer_block_index * read_count + i][:, :seq_len, :, :]
)

return tuple(reads)
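
As a hedged usage sketch of the direct-cache read() above: the state list is laid out as [k_0, v_0, k_1, v_1, ...] and read() returns views truncated to seq_len. The shapes below are illustrative, not taken from a real model config.

import torch

transformer_block_count, bs, sl, heads, dim = 2, 2, 16, 8, 32
state = [
    torch.empty(bs, sl, heads, dim)
    for _ in range(2 * transformer_block_count)
]

# k/v for transformer block 1, truncated to the current sequence length,
# mirroring state[transformer_block_index * read_count + i][:, :seq_len].
transformer_block_index, read_count, seq_len = 1, 2, 5
k_view = state[transformer_block_index * read_count + 0][:, :seq_len]
v_view = state[transformer_block_index * read_count + 1][:, :seq_len]
assert k_view.shape == (bs, seq_len, heads, dim)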

def write_timestep(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
# List of [bs, 1, attn_head_count, attn_head_dim]
cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
# [bs]
seq_positions: Union[torch.Tensor, ReplicatedTensor],
# [bs, max_seqlen // block_pos_stride]
page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
):
"""Writes a single batched timestep across all cache partitions.

Note that this internally loops over the batch size, which cannot be
dynamic.
"""
bs, _, _, _ = cache_partitions[0].shape
update_count = len(cache_partitions)

for b in range(bs):
row_index = torch.tensor(b, dtype=torch.int64)
row_start_pos = seq_positions[row_index]

for i, update in enumerate(cache_partitions):
cache = state[transformer_block_index * update_count + i]
cache.index_put_((row_index, row_start_pos), update[row_index, 0])
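
The per-row scatter above can be pictured with plain tensors: each batch row writes its single new timestep at its own ragged position. A small sketch under assumed shapes:

import torch

bs, sl, heads, dim = 2, 16, 4, 8
cache = torch.zeros(bs, sl, heads, dim)
update = torch.ones(bs, 1, heads, dim)        # one new timestep per row
seq_positions = torch.tensor([3, 7])          # ragged write positions

for b in range(bs):
    row_index = torch.tensor(b, dtype=torch.int64)
    row_start_pos = seq_positions[row_index]
    cache.index_put_((row_index, row_start_pos), update[row_index, 0])

assert torch.all(cache[0, 3] == 1) and torch.all(cache[1, 7] == 1)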

def write(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
):
"""Writes cache partitions from a linear layout to the page table.

This is the inverse of the linear read. The same caveat applies if the
in-place scatter cannot be fused.
"""
update_count = len(cache_partitions)

for idx, update_src in enumerate(cache_partitions):
cache_dest = state[transformer_block_index * update_count + idx]
_, batch_seq_len, _, _ = update_src.shape
cache_dest[:, :batch_seq_len, :, :] = update_src
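
For contrast, write() is the prefill-style bulk path: the whole linear update is copied into the leading batch_seq_len positions of each destination. A minimal sketch with illustrative shapes:

import torch

bs, sl, heads, dim = 2, 16, 4, 8
cache_dest = torch.zeros(bs, sl, heads, dim)
update_src = torch.ones(bs, 6, heads, dim)    # batch_seq_len == 6

_, batch_seq_len, _, _ = update_src.shape
cache_dest[:, :batch_seq_len, :, :] = update_src
assert torch.all(cache_dest[:, :6] == 1) and torch.all(cache_dest[:, 6:] == 0)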


class PagedKVCache(BaseKVCache):
"""Implementation of a KV cache on top of a 'page table'.
@@ -238,31 +334,27 @@ def allocate(
"""Allocates tensor state for a page table for the given capacity in
pages.
"""
shards = [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
for _ in range(self.shard_count)
]

if self.shard_count == 1:
return [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
]
else:
shards = [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
for _ in range(self.shard_count)
]
return [SplitPrimitiveTensor(ts=shards, shard_dim=1)]
return shards

return [SplitPrimitiveTensor(ts=shards, shard_dim=1)]

def read(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
seq_len: int,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Reads cache partitions from the page table for the given page_ids.
@@ -331,6 +423,8 @@ def read_cache_partition(
for index, read_into_partition in enumerate(read_into_partitions):
read_cache_partition(index, read_into_partition)

return tuple([p[:, :seq_len, :] for p in read_into_partitions])
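
The behavioral change here is that the paged read() now both fills read_into_partitions in place and returns the same partitions truncated to seq_len, so callers can consume the return value directly. A stand-in sketch of that contract (fake_read is hypothetical, not the sharktank implementation):

import torch

def fake_read(read_into_partitions, *, seq_len):
    for p in read_into_partitions:
        p.fill_(1.0)                  # stands in for the gather from pages
    return tuple(p[:, :seq_len, :] for p in read_into_partitions)

xk_temp = torch.zeros(2, 16, 4, 8)
xv_temp = torch.zeros(2, 16, 4, 8)
xk, xv = fake_read([xk_temp, xv_temp], seq_len=5)
assert xk.shape[1] == 5 and torch.all(xk_temp == 1)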

def write_timestep(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
156 changes: 53 additions & 103 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -113,27 +113,16 @@ def forward(
# Full sequence length.
kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride

if self.cache.is_paged:
xk, xv = self.transact_cache_paged(
xk_cache_update=xk,
xv_cache_update=xv,
seq_block_ids=seq_block_ids,
kv_seq_len=kv_seq_len,
start_positions=start_positions,
cache_state=cache_state,
xk_temp=xk_temp,
xv_temp=xv_temp,
)
elif self.cache.is_direct:
xk, xv = self.transact_cache_direct(
xk_cache_update=xk,
xv_cache_update=xv,
start_positions=start_positions,
kv_seq_len=kv_seq_len,
cache_state=cache_state,
)
else:
raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}")
xk, xv = self.transact_cache(
xk_cache_update=xk,
xv_cache_update=xv,
seq_block_ids=seq_block_ids,
kv_seq_len=kv_seq_len,
start_positions=start_positions,
cache_state=cache_state,
xk_temp=xk_temp,
xv_temp=xv_temp,
)

# Expand kv heads for GQA.
gqa_n_rep = self.head_count // self.head_count_kv
@@ -202,58 +191,20 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
h = h + attn_output
return h

def transact_cache_direct(
self,
*,
cache_state: list[torch.Tensor],
xk_cache_update: torch.Tensor,
xv_cache_update: torch.Tensor,
kv_seq_len: int,
start_positions: Optional[torch.Tensor] = None,
):
bs, batch_seq_len, _, _ = xk_cache_update.shape
cache_k = cache_state[self.block_index * 2]
cache_v = cache_state[self.block_index * 2 + 1]

if start_positions is None:
# Prefill. Write the entire cache.
cache_k[:, :batch_seq_len] = xk_cache_update
cache_v[:, :batch_seq_len] = xv_cache_update
return xk_cache_update, xv_cache_update
else:
# Decode. Write a single timestep.
# TODO: This needs to be reworked with index ops.
assert xk_cache_update.shape[1] == 1
assert xv_cache_update.shape[1] == 1
for b in range(bs):
# Make a tensor because indices must be all tensors, so we can avoid
# doing start_positions[row_index].item(), which generates a lot of SymInts.
row_index = torch.tensor(
b, dtype=torch.int64, device=xk_cache_update.device
)
row_start_pos = start_positions[row_index]
cache_k.index_put_(
(row_index, row_start_pos), xk_cache_update[row_index, 0]
)
cache_v.index_put_(
(row_index, row_start_pos), xv_cache_update[row_index, 0]
)
return cache_k[:, :kv_seq_len], cache_v[:, :kv_seq_len]

def transact_cache_paged(
def transact_cache(
self,
*,
xk_cache_update: torch.Tensor,
xv_cache_update: torch.Tensor,
cache_state: list[torch.Tensor],
# [bs, batch_seq_len // block_seq_stride]
seq_block_ids: torch.Tensor,
seq_block_ids: Optional[torch.Tensor],
kv_seq_len: int,
start_positions: Optional[torch.Tensor] = None,
xk_temp: Optional[torch.Tensor] = None,
xv_temp: Optional[torch.Tensor] = None,
):
cache = self.cache.paged
cache = self.cache
# Manage the cache.
if start_positions is None:
# Prefill: Write the entire cache.
@@ -264,46 +215,45 @@ def transact_cache_paged(
page_ids=seq_block_ids,
)
return xk_cache_update, xv_cache_update
else:
# Decode at ragged start positions.
# We need to initialize/read the K/V from the cache for the whole
# sequence. Note that at this point, it is possible to fork and
# use a memory efficient attention kernel that can do indirect
# reads, skipping this materialization. This path is taken for
# a decode step.
assert xk_temp is not None and xv_temp is not None
assert xk_cache_update.shape[1] == 1
assert xv_cache_update.shape[1] == 1
assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride

# Write our one updated cache row into the cache.
cache.write_timestep(
cache_state,
cache_partitions=[
xk_cache_update,
xv_cache_update,
],
transformer_block_index=self.block_index,
seq_positions=start_positions,
page_ids=seq_block_ids,
)

# Restore from the cache.
cache.read(
cache_state,
read_into_partitions=[
xk_temp[:, 0:kv_seq_len, ...],
xv_temp[:, 0:kv_seq_len, ...],
],
transformer_block_index=self.block_index,
page_ids=seq_block_ids,
)
# Decode at ragged start positions.
# We need to initialize/read the K/V from the cache for the whole
# sequence. Note that at this point, it is possible to fork and
# use a memory efficient attention kernel that can do indirect
# reads, skipping this materialization. This path is taken for
# a decode step.
assert xk_temp is not None and xv_temp is not None
assert xk_cache_update.shape[1] == 1
assert xv_cache_update.shape[1] == 1
assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride

# Write our one updated cache row into the cache.
cache.write_timestep(
cache_state,
cache_partitions=[
xk_cache_update,
xv_cache_update,
],
transformer_block_index=self.block_index,
seq_positions=start_positions,
page_ids=seq_block_ids,
)

# Restore from the cache.
xk, xv = cache.read(
cache_state,
read_into_partitions=[
xk_temp[:, 0:kv_seq_len, ...],
xv_temp[:, 0:kv_seq_len, ...],
],
transformer_block_index=self.block_index,
page_ids=seq_block_ids,
seq_len=kv_seq_len,
)

# For computation, we create a subview of the xk/xv tensors to have
# a sequence length covering the blocked size. This must include
# the newly added row (the caller is responsible for ensuring that
# every block has at least one row left). We'll compute on this
# ragged view and use an appropriate mask.
xk = xk_temp[:, 0:kv_seq_len, ...]
xv = xv_temp[:, 0:kv_seq_len, ...]
return xk, xv
# For computation, we create a subview of the xk/xv tensors to have
# a sequence length covering the blocked size. This must include
# the newly added row (the caller is responsible for ensuring that
# every block has at least one row left). We'll compute on this
# ragged view and use an appropriate mask.
return xk, xv
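
To summarize the unified control flow, the sketch below walks the two paths through transact_cache() with a stand-in cache object: start_positions=None means prefill (bulk write, attend over the update itself), otherwise decode (write one ragged timestep, then read the full kv_seq_len window back). FakeCache and all shapes are illustrative; the real DirectKVCache/PagedKVCache classes are not constructed here.

import torch

class FakeCache:
    """Stand-in with the same method names as the sharktank caches."""
    block_seq_stride = 4

    def write(self, state, cache_partitions, *, transformer_block_index, page_ids=None):
        print("prefill: bulk write", [tuple(p.shape) for p in cache_partitions])

    def write_timestep(self, state, cache_partitions, *, transformer_block_index,
                       seq_positions, page_ids=None):
        print("decode: write timestep at", seq_positions.tolist())

    def read(self, state, *, read_into_partitions, transformer_block_index,
             seq_len, page_ids=None):
        print("decode: read back", seq_len, "positions")
        return tuple(read_into_partitions)

cache, state = FakeCache(), []
bs, heads, dim = 2, 4, 8
seq_block_ids = torch.zeros(bs, 3, dtype=torch.int64)
kv_seq_len = seq_block_ids.shape[1] * cache.block_seq_stride  # 12

# Prefill: start_positions is None, so the whole prompt is written at once.
xk = xv = torch.empty(bs, kv_seq_len, heads, dim)
cache.write(state, [xk, xv], transformer_block_index=0, page_ids=seq_block_ids)

# Decode: one new token per row at its ragged position, then read back the
# full window into scratch buffers for attention.
start_positions = torch.tensor([5, 9])
xk_new = xv_new = torch.empty(bs, 1, heads, dim)
cache.write_timestep(state, [xk_new, xv_new], transformer_block_index=0,
                     seq_positions=start_positions, page_ids=seq_block_ids)
xk_full, xv_full = cache.read(
    state,
    read_into_partitions=[torch.empty(bs, kv_seq_len, heads, dim),
                          torch.empty(bs, kv_seq_len, heads, dim)],
    transformer_block_index=0,
    page_ids=seq_block_ids,
    seq_len=kv_seq_len,
)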
2 changes: 1 addition & 1 deletion sharktank/sharktank/types/tensors.py
@@ -990,7 +990,7 @@ def _is_slicing_split_dim(self, key):
else:
# Any other collection is indexing only dimension 0.
return self.shard_dim == 0
if len(key) < self.shard_dim:
if len(key) <= self.shard_dim:
return False
if not isinstance(key[self.shard_dim], slice):
return True
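
The one-character change above fixes an off-by-one in _is_slicing_split_dim: a key tuple only reaches the shard dimension when it has an entry at index shard_dim, i.e. when len(key) > shard_dim. A minimal illustration with made-up values:

# With shard_dim == 2, a key that indexes only dims 0 and 1 never touches the
# split dimension, yet the old `len(key) < shard_dim` test is False for
# len(key) == 2 and fell through toward key[shard_dim], which does not exist.
shard_dim = 2
key = (slice(None), 0)

assert len(key) <= shard_dim           # new check: split dim is untouched
assert not (len(key) < shard_dim)      # old check missed exactly this case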