From 7d87c2886be945bb882c7cfc5311b374e2730958 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 14 Oct 2024 22:41:18 -0700 Subject: [PATCH] cleanup --- sharktank/sharktank/layers/kv_cache.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index b65a8889a..b02ca4fee 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -115,19 +115,27 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]: Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim] """ - shards = [[torch.empty( - [bs, self.seq_length, self.attn_head_count // self.shard_count, self.attn_head_dim], + allocations = [ + torch.empty( + [ + bs, + self.seq_length, + self.attn_head_count, + self.attn_head_dim, + ], dtype=self.dtype, device=self.device, - ) for i in range(self.shard_count)] + ) for _ in range(2 * self.transformer_block_count) ] if self.shard_count == 1: - return [shard[0] for shard in shards] - - return [SplitPrimitiveTensor(ts=shrds, shard_dim=2) for shrds in shards] + return allocations + return [ + ops.reshard_split(allocation, dim=2, count=self.shard_count) + for allocation in allocations + ] def read( self, @@ -156,7 +164,9 @@ def read( read_count = len(read_into_partitions) reads = [] for i in range(read_count): - reads.append(state[transformer_block_index * read_count + i][:, :seq_len, :, :]) + reads.append( + state[transformer_block_index * read_count + i][:, :seq_len, :, :] + ) return tuple(reads)