Skip to content

Commit

Permalink
Tweaks.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jan 9, 2025
1 parent d305f40 commit 26f115f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
4 changes: 2 additions & 2 deletions moshi/moshi/modules/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def apply_rope(
koi = kr * roti + ki * rotr

dtype = q.dtype
qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)
qo = torch.stack([qor, qoi], dim=-1).to(dtype)
ko = torch.stack([kor, koi], dim=-1).to(dtype)

return qo.view(*dims, D), ko.view(*dims, D)

Expand Down
17 changes: 11 additions & 6 deletions moshi/moshi/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,13 @@ def __init__(
):
super().__init__()
self.capacity = capacity
self.register_buffer("cache", torch.zeros(
(2, batch_size, num_heads, capacity, dim_per_head),
self.register_buffer("k_cache", torch.zeros(
(batch_size, num_heads, capacity, dim_per_head),
device=device,
dtype=dtype,
))
self.register_buffer("v_cache", torch.zeros(
(batch_size, num_heads, capacity, dim_per_head),
device=device,
dtype=dtype,
))
Expand All @@ -246,11 +251,11 @@ def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult:
assert T > 0
indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset
indexes = indexes % self.capacity
self.cache[0].index_copy_(2, indexes, k)
self.cache[1].index_copy_(2, indexes, v)
self.k_cache[:, :, indexes] = k
self.v_cache[:, :, indexes] = v

keys = self.cache[0]
values = self.cache[1]
keys = self.k_cache
values = self.v_cache

indexes = torch.arange(
self.capacity, device=self.end_offset.device, dtype=torch.long
Expand Down

0 comments on commit 26f115f

Please sign in to comment.