Skip to content

Commit

Permalink
fix accidental reintroduction of overrridden kvcache (#813)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Jul 20, 2024
1 parent 3c6666f commit d15b64c
Showing 1 changed file with 0 additions and 37 deletions.
37 changes: 0 additions & 37 deletions thunder/tests/litgpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,43 +118,6 @@
name_to_config = {config["name"]: config for config in configs}


class OverridenKVCache(nn.Module):
def __init__(
self,
k_shape: tuple[int, int, int, int],
v_shape: tuple[int, int, int, int],
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()
self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)

def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# move the buffer to the activation dtype for when AMP is used
self.k = self.k.to(k.dtype)
self.v = self.v.to(v.dtype)
# update the cache
# NOTE: `torch._dynamo.is_compiling` is being deprecated, we should update this once all versions have `torch.compiler.is_compiling`.
is_compiling = (
torch.compiler.is_compiling if hasattr(torch.compiler, "is_compiling") else torch._dynamo.is_compiling
)
if is_compiling():
# inductor doesn't support `index_add` with bfloat16
k = self.k.index_copy_(2, input_pos, k)
v = self.v.index_copy_(2, input_pos, v)
return k, v
# See issue: "Support more indexing operators (index_copy and index_add)"
k = self.k = self.k.index_copy_(2, input_pos, k)
v = self.v = self.v.index_copy_(2, input_pos, v)
# THUNDER bug: cannot return self.k, self.v here (may be cuda graphs related - no minimum repro)
return k, v

def reset_parameters(self) -> None:
torch.nn.init.zeros_(self.k)
torch.nn.init.zeros_(self.v)


import litgpt

# add the testing configurations
Expand Down

0 comments on commit d15b64c

Please sign in to comment.