Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
rsuderman committed Oct 16, 2024
1 parent f36b3ff commit 3444f6d
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions sharktank/sharktank/layers/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,27 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]:
Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim]
"""
shards = [[torch.empty(
[bs, self.seq_length, self.attn_head_count // self.shard_count, self.attn_head_dim],
allocations = [
torch.empty(
[
bs,
self.seq_length,
self.attn_head_count,
self.attn_head_dim,
],
dtype=self.dtype,
device=self.device,
) for i in range(self.shard_count)]
)
for _ in range(2 * self.transformer_block_count)
]

if self.shard_count == 1:
return [shard[0] for shard in shards]

return [SplitPrimitiveTensor(ts=shrds, shard_dim=2) for shrds in shards]
return allocations

return [
ops.reshard_split(allocation, dim=2, count=self.shard_count)
for allocation in allocations
]

def read(
self,
Expand Down Expand Up @@ -156,7 +164,9 @@ def read(
read_count = len(read_into_partitions)
reads = []
for i in range(read_count):
reads.append(state[transformer_block_index * read_count + i][:, :seq_len, :, :])
reads.append(
state[transformer_block_index * read_count + i][:, :seq_len, :, :]
)

return tuple(reads)

Expand Down

0 comments on commit 3444f6d

Please sign in to comment.