diff --git a/thunder/tests/litgpt_model.py b/thunder/tests/litgpt_model.py index 6dbd35f2ef..0926715952 100644 --- a/thunder/tests/litgpt_model.py +++ b/thunder/tests/litgpt_model.py @@ -118,43 +118,6 @@ name_to_config = {config["name"]: config for config in configs} -class OverridenKVCache(nn.Module): - def __init__( - self, - k_shape: tuple[int, int, int, int], - v_shape: tuple[int, int, int, int], - device: torch.device | None = None, - dtype: torch.dtype | None = None, - ) -> None: - super().__init__() - self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False) - self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False) - - def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - # move the buffer to the activation dtype for when AMP is used - self.k = self.k.to(k.dtype) - self.v = self.v.to(v.dtype) - # update the cache - # NOTE: `torch._dynamo.is_compiling` is being deprecated, we should update this once all versions have `torch.compiler.is_compiling`. - is_compiling = ( - torch.compiler.is_compiling if hasattr(torch.compiler, "is_compiling") else torch._dynamo.is_compiling - ) - if is_compiling(): - # inductor doesn't support `index_add` with bfloat16 - k = self.k.index_copy_(2, input_pos, k) - v = self.v.index_copy_(2, input_pos, v) - return k, v - # See issue: "Support more indexing operators (index_copy and index_add)" - k = self.k = self.k.index_copy_(2, input_pos, k) - v = self.v = self.v.index_copy_(2, input_pos, v) - # THUNDER bug: cannot return self.k, self.v here (may be cuda graphs related - no minimum repro) - return k, v - - def reset_parameters(self) -> None: - torch.nn.init.zeros_(self.k) - torch.nn.init.zeros_(self.v) - - import litgpt # add the testing configurations