diff --git a/moshi/moshi/models/compression.py b/moshi/moshi/models/compression.py
index 7a790d5..d98f007 100644
--- a/moshi/moshi/models/compression.py
+++ b/moshi/moshi/models/compression.py
@@ -143,7 +143,7 @@ def __init__(
         freeze_encoder: bool = False,
         freeze_quantizer: bool = False,
         freeze_quantizer_level: int = -1,
-        torch_compile_encoder_decoder: bool = False,
+        torch_compile_encoder_decoder: bool = True,
     ):
         super().__init__()
         self.encoder = encoder
diff --git a/moshi/moshi/modules/rope.py b/moshi/moshi/modules/rope.py
index e3cd115..72861f7 100644
--- a/moshi/moshi/modules/rope.py
+++ b/moshi/moshi/modules/rope.py
@@ -62,8 +62,8 @@ def apply_rope(
     koi = kr * roti + ki * rotr
 
     dtype = q.dtype
-    qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
-    ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)
+    qo = torch.stack([qor, qoi], dim=-1).to(dtype)
+    ko = torch.stack([kor, koi], dim=-1).to(dtype)
 
     return qo.view(*dims, D), ko.view(*dims, D)
 
diff --git a/moshi/moshi/modules/transformer.py b/moshi/moshi/modules/transformer.py
index 8c79a2a..f82e019 100644
--- a/moshi/moshi/modules/transformer.py
+++ b/moshi/moshi/modules/transformer.py
@@ -208,7 +208,7 @@ def from_kv(keys: torch.Tensor, values: torch.Tensor) -> "KVCacheResult":
         return KVCacheResult(keys, values, positions)
 
 
-class RingKVCache:
+class RingKVCache(nn.Module):
     """Efficient streaming KVCache to be compatible with Cuda Graph.
 
     Args:
@@ -228,13 +228,19 @@ def __init__(
         device: torch.device = torch.device("cuda"),
         dtype: torch.dtype = torch.bfloat16,
     ):
+        super().__init__()
         self.capacity = capacity
-        self.cache = torch.zeros(
-            (2, batch_size, num_heads, capacity, dim_per_head),
+        self.register_buffer("k_cache", torch.zeros(
+            (batch_size, num_heads, capacity, dim_per_head),
             device=device,
             dtype=dtype,
-        )
-        self.end_offset = torch.zeros(1, device=device, dtype=torch.long)
+        ))
+        self.register_buffer("v_cache", torch.zeros(
+            (batch_size, num_heads, capacity, dim_per_head),
+            device=device,
+            dtype=dtype,
+        ))
+        self.register_buffer("end_offset", torch.zeros(1, device=device, dtype=torch.long))
 
     def reset(self):
         self.end_offset.zero_()
@@ -245,11 +251,11 @@ def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult:
         assert T > 0
         indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset
         indexes = indexes % self.capacity
-        self.cache[0].index_copy_(2, indexes, k)
-        self.cache[1].index_copy_(2, indexes, v)
+        self.k_cache[:, :, indexes] = k
+        self.v_cache[:, :, indexes] = v
 
-        keys = self.cache[0]
-        values = self.cache[1]
+        keys = self.k_cache
+        values = self.v_cache
 
         indexes = torch.arange(
             self.capacity, device=self.end_offset.device, dtype=torch.long
@@ -280,13 +286,9 @@ def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult:
 
 @dataclass
 class _MHAState:
-    kv_cache: RingKVCache
-    offset: torch.Tensor
     offset_cpu: int
 
     def reset(self):
-        self.kv_cache.reset()
-        self.offset.zero_()
         self.offset_cpu = 0
 
 
@@ -342,6 +344,21 @@ def __init__(
         self.out_proj = nn.Linear(
             embed_dim, mult * embed_dim, bias=False, **factory_kwargs
         )
+        self.register_buffer("offset", torch.zeros(1, device=in_proj.weight.device, dtype=torch.long))
+        dim_per_head = self.embed_dim // self.num_heads
+        dtype = self.in_proj_weight.dtype
+        if self.context is None:
+            if self.weights_per_step:
+                capacity = self.weights_per_step
+            else:
+                raise RuntimeError(
+                    "Cannot create a streaming KVCache without a context to estimate capacity."
+                )
+        else:
+            capacity = self.context
+        self.kv_cache = RingKVCache(
+            1, self.num_heads, dim_per_head, capacity, device, dtype
+        )
 
     def _init_streaming_state(self, batch_size: int) -> _MHAState:
         if self.context is None:
@@ -357,12 +374,7 @@ def _init_streaming_state(self, batch_size: int) -> _MHAState:
         # TODO: the following estimation will not work great with FSDP.
         dtype = self.in_proj_weight.dtype
         dim_per_head = self.embed_dim // self.num_heads
-        kv_cache = RingKVCache(
-            batch_size, self.num_heads, dim_per_head, capacity, device, dtype
-        )
         return _MHAState(
-            kv_cache,
-            offset=torch.zeros(1, device=device, dtype=torch.long),
             offset_cpu=0,
         )
 
@@ -371,7 +383,7 @@ def _complete_kv(self, k, v) -> KVCacheResult:
         if state is None:
             return KVCacheResult.from_kv(k, v)
         else:
-            return state.kv_cache.complete(k, v)
+            return self.kv_cache.complete(k, v)
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
         state = self._streaming_state
@@ -382,7 +394,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
             offset_cpu = 0
         else:
             assert self.causal, "Streaming only available for causal"
-            offset = state.offset
+            offset = self.offset
             offset_cpu = state.offset_cpu
 
         if self.weights_per_step:
@@ -418,7 +430,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
         else:
             x = self.out_proj(x)
         if state is not None:
-            state.offset.add_(T)
+            self.offset.add_(T)
             state.offset_cpu += T
         return x
 
@@ -581,9 +593,6 @@ def _sa_block(self, x: torch.Tensor):
         return x_orig + self.layer_scale_1(update)
 
     def forward(self, x: torch.Tensor):
-        with ExitStack() as stack:
-            if x.device.type != 'cuda':
-                stack.enter_context(no_compile())
         x = self._sa_block(x)
         x = self._ff_block(x)
         state = self._streaming_state
@@ -594,10 +603,8 @@
 
 @dataclass
 class _TransformerState:
-    offset: torch.Tensor
-
     def reset(self):
-        self.offset.zero_()
+        pass
 
 
 class StreamingTransformer(StreamingModule[_TransformerState]):
@@ -666,10 +673,11 @@ def __init__(
                 **kwargs,
             )
         )
+        self.register_buffer("offset", torch.zeros(1, device=device, dtype=torch.long))
 
     def _init_streaming_state(self, batch_size: int) -> _TransformerState:
         device = next(self.parameters()).device
-        return _TransformerState(offset=torch.zeros(1, device=device, dtype=torch.long))
+        return _TransformerState()
 
     def forward(self, x: torch.Tensor, *args, **kwargs):
         B, T, C = x.shape
@@ -678,7 +686,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs):
         if state is None:
             offset = torch.zeros(1, dtype=torch.long, device=x.device)
         else:
-            offset = state.offset
+            offset = self.offset
 
         if self.positional_embedding in {"sin", "sin_rope"}:
             positions = torch.arange(T, device=x.device).view(1, -1, 1)
@@ -692,7 +700,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs):
             x = layer(x, *args, **kwargs)
 
         if state is not None:
-            state.offset.add_(T)
+            self.offset.add_(T)
 
         return x
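
The common thread in the transformer.py changes: streaming state (the ring KV cache and the step offsets) moves out of per-stream dataclass fields and plain attributes into buffers registered on the owning nn.Module, so every update is an in-place write into tensors with a stable storage location, which is what CUDA graph capture and torch.compile-friendly code expect (and presumably why torch_compile_encoder_decoder now defaults to True). Below is a minimal, self-contained sketch of that pattern, not the moshi implementation: TinyRingKVCache is a hypothetical name, and the batch handling, device defaults, and the position/validity bookkeeping that the real RingKVCache.complete returns in KVCacheResult are all omitted.

import torch
from torch import nn


class TinyRingKVCache(nn.Module):
    """Simplified ring KV cache illustrating the registered-buffer design."""

    def __init__(self, batch_size: int, num_heads: int, dim_per_head: int,
                 capacity: int, device=None, dtype=torch.float32):
        super().__init__()
        self.capacity = capacity
        shape = (batch_size, num_heads, capacity, dim_per_head)
        # Buffers, not plain attributes: they follow .to()/state_dict() and give the
        # cache a fixed storage that is only ever mutated in place.
        self.register_buffer("k_cache", torch.zeros(shape, device=device, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, device=device, dtype=dtype))
        self.register_buffer("end_offset", torch.zeros(1, device=device, dtype=torch.long))

    def reset(self):
        # Restart the stream without reallocating anything.
        self.end_offset.zero_()

    def complete(self, k: torch.Tensor, v: torch.Tensor):
        # k, v: (batch, heads, new_steps, dim_per_head)
        T = k.shape[2]
        # Ring addressing: new timesteps wrap around the fixed capacity.
        idx = (torch.arange(T, device=k.device) + self.end_offset) % self.capacity
        self.k_cache[:, :, idx] = k
        self.v_cache[:, :, idx] = v
        self.end_offset.add_(T)
        return self.k_cache, self.v_cache


if __name__ == "__main__":
    cache = TinyRingKVCache(batch_size=1, num_heads=2, dim_per_head=4, capacity=8)
    k = torch.randn(1, 2, 3, 4)
    v = torch.randn(1, 2, 3, 4)
    keys, values = cache.complete(k, v)
    print(keys.shape, int(cache.end_offset))      # torch.Size([1, 2, 8, 4]) 3
    print("end_offset" in cache.state_dict())     # True: buffers travel with the module

The same reasoning applies to the offset counters moved onto StreamingMultiheadAttention and StreamingTransformer in this diff: they become long buffers updated with add_(), so the streaming path never rebinds a tensor, it only mutates one.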