diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 251e55814f..d71ffbc460 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -509,8 +509,12 @@ def forward( dim=-2, ) if sliding_window > 0 and key.size(2) > sliding_window: - self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][:, :, 1:, :] - self.layer_cache[1]["values"] = self.layer_cache[1]["values"][:, :, 1:, :] + self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][ + :, :, 1:, : + ] + self.layer_cache[1]["values"] = self.layer_cache[1]["values"][ + :, :, 1:, : + ] context = self.flash_attn_with_kvcache( query.transpose(1, 2), self.layer_cache[1]["keys"].transpose(1, 2),