Skip to content

Commit

Permalink
add sharded tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rsuderman committed Oct 28, 2024
1 parent 9302f56 commit 9d0e19a
Showing 1 changed file with 167 additions and 19 deletions.
186 changes: 167 additions & 19 deletions sharktank/tests/layers/kv_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch

from sharktank.ops import replicate, unshard
from sharktank.ops import replicate, reshard_split, unshard
from sharktank.layers import *
from sharktank.types import *

Expand Down Expand Up @@ -149,19 +149,25 @@ def test_sharded_direct():
write_seq_length = seq_length - 5

# Write a prefill in:
write_ones = torch.full(
(bs, write_seq_length, attn_head_count // shard_count, attn_head_dim),
1.0,
dtype=torch.float32,
)
write_twos = torch.full(
(bs, write_seq_length, attn_head_count // shard_count, attn_head_dim),
2.0,
dtype=torch.float32,
write_ones = reshard_split(
torch.full(
(bs, write_seq_length, attn_head_count, attn_head_dim),
1.0,
dtype=torch.float32,
),
dim=2,
count=shard_count,
)

write_ones = SplitPrimitiveTensor(ts=[write_ones] * shard_count, shard_dim=2)
write_twos = SplitPrimitiveTensor(ts=[write_twos] * shard_count, shard_dim=2)
write_twos = reshard_split(
torch.full(
(bs, write_seq_length, attn_head_count, attn_head_dim),
2.0,
dtype=torch.float32,
),
dim=2,
count=shard_count,
)

cache.write(
allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1
Expand All @@ -186,16 +192,17 @@ def test_sharded_direct():
torch.testing.assert_close(unshard(write_twos), unshard(read_back[1]))

# Write timestep
write_threes = torch.full(
(bs, 1, attn_head_count // shard_count, attn_head_dim), 3.0, dtype=torch.float32
write_threes = reshard_split(
torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32),
dim=2,
count=shard_count,
)
write_fours = torch.full(
(bs, 1, attn_head_count // shard_count, attn_head_dim), 4.0, dtype=torch.float32
write_fours = reshard_split(
torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32),
dim=2,
count=shard_count,
)

write_threes = SplitPrimitiveTensor(ts=[write_threes] * shard_count, shard_dim=2)
write_fours = SplitPrimitiveTensor(ts=[write_fours] * shard_count, shard_dim=2)

write_pos = replicate(
torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count
)
Expand Down Expand Up @@ -352,3 +359,144 @@ def test_paged():

torch.testing.assert_close(check_concat_0, read_back[0])
torch.testing.assert_close(check_concat_1, read_back[1])


def test_sharded_paged():
bs = 4
seq_length = 24
attn_head_count = 8
attn_head_dim = 16
transformer_block_count = 4
block_seq_stride = 4
shard_count = 4
cache = PagedKVCache(
block_seq_stride=block_seq_stride,
transformer_block_count=transformer_block_count,
attn_head_count=attn_head_count,
attn_head_dim=attn_head_dim,
shard_count=shard_count,
dtype=torch.float32,
device=None,
)

write_seq_length = seq_length - 4
page_count = bs * seq_length // block_seq_stride
page_ids = torch.arange(page_count, dtype=torch.int64)
page_ids = page_ids.view(bs, seq_length // block_seq_stride)
page_ids = replicate(page_ids, shard_count)
write_page_ids = page_ids[:, : write_seq_length // block_seq_stride]

allocation = cache.allocate(page_count=page_count)

# Write a prefill in:
write_ones = reshard_split(
torch.full(
(bs, write_seq_length, attn_head_count, attn_head_dim),
1.0,
dtype=torch.float32,
),
dim=2,
count=shard_count,
)
write_twos = reshard_split(
torch.full(
(bs, write_seq_length, attn_head_count, attn_head_dim),
2.0,
dtype=torch.float32,
),
dim=2,
count=shard_count,
)

cache.write(
allocation,
cache_partitions=[write_ones, write_twos],
transformer_block_index=1,
page_ids=write_page_ids,
)

# Check the written values have updated:
empty_k = reshard_split(
torch.empty(
(bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32
),
dim=2,
count=shard_count,
)

empty_v = reshard_split(
torch.empty(
(bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32
),
dim=2,
count=shard_count,
)

read_empty = [empty_k, empty_v]

read_back = cache.read(
allocation,
read_into_partitions=read_empty,
transformer_block_index=1,
seq_len=write_seq_length,
page_ids=write_page_ids,
)
torch.testing.assert_close(unshard(write_ones), unshard(read_back[0]))
torch.testing.assert_close(unshard(write_twos), unshard(read_back[1]))

# Write timestep
write_threes = reshard_split(
torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32),
dim=2,
count=shard_count,
)

write_fours = reshard_split(
torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32),
dim=2,
count=shard_count,
)

write_pos = replicate(
torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count
)

cache.write_timestep(
allocation,
cache_partitions=[write_threes, write_fours],
transformer_block_index=1,
seq_positions=write_pos,
page_ids=page_ids,
)

empty_k = reshard_split(
torch.zeros(
(bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim),
dtype=torch.float32,
),
dim=2,
count=shard_count,
)

empty_v = reshard_split(
torch.zeros(
(bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim),
dtype=torch.float32,
),
dim=2,
count=shard_count,
)

read_back = cache.read(
allocation,
read_into_partitions=[empty_k, empty_v],
transformer_block_index=1,
seq_len=write_seq_length + 1,
page_ids=page_ids,
)

check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1)
check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1)

torch.testing.assert_close(check_concat_0, unshard(read_back[0]))
torch.testing.assert_close(check_concat_1, unshard(read_back[1]))

0 comments on commit 9d0e19a

Please sign in to comment.