Allow ExecuTorch compilation #182

Draft · wants to merge 3 commits into base: main
2 changes: 1 addition & 1 deletion moshi/moshi/models/compression.py
@@ -143,7 +143,7 @@ def __init__(
freeze_encoder: bool = False,
freeze_quantizer: bool = False,
freeze_quantizer_level: int = -1,
torch_compile_encoder_decoder: bool = False,
torch_compile_encoder_decoder: bool = True,
):
super().__init__()
self.encoder = encoder
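Note: the only change in this file flips the default of torch_compile_encoder_decoder to True. As a rough, hypothetical sketch of what such a flag usually gates (the actual wiring in compression.py may differ):

# --- illustrative sketch, not part of the diff (hypothetical class) ---
import torch
from torch import nn

class CompressionModelSketch(nn.Module):
    """Hypothetical: compile the encoder/decoder when the flag is set (default now True)."""
    def __init__(self, encoder: nn.Module, decoder: nn.Module,
                 torch_compile_encoder_decoder: bool = True):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        if torch_compile_encoder_decoder:
            # torch.compile wraps each module in an optimized callable.
            self.encoder = torch.compile(self.encoder)
            self.decoder = torch.compile(self.decoder)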
4 changes: 2 additions & 2 deletions moshi/moshi/modules/rope.py
@@ -62,8 +62,8 @@ def apply_rope(
koi = kr * roti + ki * rotr

dtype = q.dtype
qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)
qo = torch.stack([qor, qoi], dim=-1).to(dtype)
ko = torch.stack([kor, koi], dim=-1).to(dtype)

return qo.view(*dims, D), ko.view(*dims, D)

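Note: the rope change only moves the dtype cast from each component to the stacked result, so one cast replaces two, which tends to be friendlier to export backends. A quick standalone check of the equivalence:

# --- illustrative sketch, not part of the diff ---
import torch

qor = torch.randn(2, 4, dtype=torch.float32)
qoi = torch.randn(2, 4, dtype=torch.float32)
dtype = torch.bfloat16

old = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)  # cast each, then stack
new = torch.stack([qor, qoi], dim=-1).to(dtype)            # stack, then cast once

assert torch.equal(old, new)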
68 changes: 38 additions & 30 deletions moshi/moshi/modules/transformer.py
@@ -208,7 +208,7 @@ def from_kv(keys: torch.Tensor, values: torch.Tensor) -> "KVCacheResult":
return KVCacheResult(keys, values, positions)


class RingKVCache:
class RingKVCache(nn.Module):
"""Efficient streaming KVCache to be compatible with Cuda Graph.

Args:
@@ -228,13 +228,19 @@ def __init__(
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.bfloat16,
):
super().__init__()
self.capacity = capacity
self.cache = torch.zeros(
(2, batch_size, num_heads, capacity, dim_per_head),
self.register_buffer("k_cache", torch.zeros(
(batch_size, num_heads, capacity, dim_per_head),
device=device,
dtype=dtype,
)
self.end_offset = torch.zeros(1, device=device, dtype=torch.long)
))
self.register_buffer("v_cache", torch.zeros(
(batch_size, num_heads, capacity, dim_per_head),
device=device,
dtype=dtype,
))
self.register_buffer("end_offset", torch.zeros(1, device=device, dtype=torch.long))

def reset(self):
self.end_offset.zero_()
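Note: making RingKVCache an nn.Module and holding k_cache, v_cache and end_offset as registered buffers means the cache tensors appear in state_dict, follow .to(device), and are visible to torch.export, instead of living as loose Python attributes. A minimal illustration of the buffer behaviour (standalone, not the repo's class):

# --- illustrative sketch, not part of the diff ---
import torch
from torch import nn

class BufferedCache(nn.Module):
    def __init__(self, capacity: int = 8, dim: int = 4):
        super().__init__()
        # Registered buffers are saved/loaded with the module and moved by .to().
        self.register_buffer("cache", torch.zeros(capacity, dim))
        self.register_buffer("end_offset", torch.zeros(1, dtype=torch.long))

    def reset(self) -> None:
        self.cache.zero_()
        self.end_offset.zero_()

m = BufferedCache()
print(list(m.state_dict()))  # ['cache', 'end_offset']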
@@ -245,11 +251,11 @@ def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult:
assert T > 0
indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset
indexes = indexes % self.capacity
self.cache[0].index_copy_(2, indexes, k)
self.cache[1].index_copy_(2, indexes, v)
self.k_cache[:, :, indexes] = k
self.v_cache[:, :, indexes] = v

keys = self.cache[0]
values = self.cache[1]
keys = self.k_cache
values = self.v_cache

indexes = torch.arange(
self.capacity, device=self.end_offset.device, dtype=torch.long
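Note: the write into the cache switches from index_copy_ along dim 2 to advanced-index assignment on the same dimension. For distinct indices the two are equivalent; the indexed assignment is simply easier for some export paths to lower. Standalone check:

# --- illustrative sketch, not part of the diff ---
import torch

B, H, C, D, T = 1, 2, 8, 4, 3
cache_a = torch.zeros(B, H, C, D)
cache_b = cache_a.clone()
k = torch.randn(B, H, T, D)
indexes = torch.arange(T)

cache_a.index_copy_(2, indexes, k)   # old style: in-place copy along dim 2
cache_b[:, :, indexes] = k           # new style: advanced-index assignment

assert torch.equal(cache_a, cache_b)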
@@ -280,13 +286,9 @@ def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult:

@dataclass
class _MHAState:
kv_cache: RingKVCache
offset: torch.Tensor
offset_cpu: int

def reset(self):
self.kv_cache.reset()
self.offset.zero_()
self.offset_cpu = 0


@@ -342,6 +344,21 @@ def __init__(
self.out_proj = nn.Linear(
embed_dim, mult * embed_dim, bias=False, **factory_kwargs
)
self.register_buffer("offset", torch.zeros(1, device=in_proj.weight.device, dtype=torch.long))
dim_per_head = self.embed_dim // self.num_heads
dtype = self.in_proj_weight.dtype
if self.context is None:
if self.weights_per_step:
capacity = self.weights_per_step
else:
raise RuntimeError(
"Cannot create a streaming KVCache without a context to estimate capacity."
)
else:
capacity = self.context
self.kv_cache = RingKVCache(
1, self.num_heads, dim_per_head, capacity, device, dtype
)

def _init_streaming_state(self, batch_size: int) -> _MHAState:
if self.context is None:
@@ -357,12 +374,7 @@ def _init_streaming_state(self, batch_size: int) -> _MHAState:
# TODO: the following estimation will not work great with FSDP.
dtype = self.in_proj_weight.dtype
dim_per_head = self.embed_dim // self.num_heads
kv_cache = RingKVCache(
batch_size, self.num_heads, dim_per_head, capacity, device, dtype
)
return _MHAState(
kv_cache,
offset=torch.zeros(1, device=device, dtype=torch.long),
offset_cpu=0,
)
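Note: both hunks above make the same move: the ring KV cache and the position offset are allocated once in __init__ (as a submodule and a registered buffer, with a fixed batch size of 1) instead of per stream in _init_streaming_state, since export needs the module's state to exist at trace time. A standalone sketch of the pattern, with hypothetical names:

# --- illustrative sketch, not part of the diff (hypothetical names) ---
import torch
from torch import nn

class StreamingSketch(nn.Module):
    def __init__(self, capacity: int = 16, dim: int = 8):
        super().__init__()
        # State is allocated once with a fixed shape, so the exported graph owns it;
        # nothing is created lazily in a Python dataclass.
        self.register_buffer("offset", torch.zeros(1, dtype=torch.long))
        self.register_buffer("kv", torch.zeros(1, capacity, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        T = x.shape[1]
        idx = (torch.arange(T, device=x.device) + self.offset) % self.kv.shape[1]
        self.kv[:, idx] = x            # write into the ring buffer
        self.offset.add_(T)            # advance the counter in place
        return self.kv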

@@ -371,7 +383,7 @@ def _complete_kv(self, k, v) -> KVCacheResult:
if state is None:
return KVCacheResult.from_kv(k, v)
else:
return state.kv_cache.complete(k, v)
return self.kv_cache.complete(k, v)

def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
state = self._streaming_state
@@ -382,7 +394,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
offset_cpu = 0
else:
assert self.causal, "Streaming only available for causal"
offset = state.offset
offset = self.offset
offset_cpu = state.offset_cpu

if self.weights_per_step:
@@ -418,7 +430,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
else:
x = self.out_proj(x)
if state is not None:
state.offset.add_(T)
self.offset.add_(T)
state.offset_cpu += T
return x

@@ -581,9 +593,6 @@ def _sa_block(self, x: torch.Tensor):
return x_orig + self.layer_scale_1(update)

def forward(self, x: torch.Tensor):
with ExitStack() as stack:
if x.device.type != 'cuda':
stack.enter_context(no_compile())
x = self._sa_block(x)
x = self._ff_block(x)
state = self._streaming_state
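Note: dropping the ExitStack/no_compile guard removes the only device-dependent Python branch in this forward. Graph capture resolves such Python-level checks once at trace time anyway, so a branch-free body is the export-friendly shape. A small standalone demonstration of that specialization:

# --- illustrative sketch, not part of the diff ---
import torch
from torch import nn

class Branchy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.device.type != "cuda":   # Python-level check, evaluated at trace time
            return x * 2
        return x * 3

ep = torch.export.export(Branchy(), (torch.randn(4),))
print(ep)  # only the CPU path (x * 2) appears in the exported graph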
@@ -594,10 +603,8 @@ def forward(self, x: torch.Tensor):

@dataclass
class _TransformerState:
offset: torch.Tensor

def reset(self):
self.offset.zero_()
pass


class StreamingTransformer(StreamingModule[_TransformerState]):
@@ -666,10 +673,11 @@ def __init__(
**kwargs,
)
)
self.register_buffer("offset", torch.zeros(1, device=device, dtype=torch.long))

def _init_streaming_state(self, batch_size: int) -> _TransformerState:
device = next(self.parameters()).device
return _TransformerState(offset=torch.zeros(1, device=device, dtype=torch.long))
return _TransformerState()

def forward(self, x: torch.Tensor, *args, **kwargs):
B, T, C = x.shape
@@ -678,7 +686,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs):
if state is None:
offset = torch.zeros(1, dtype=torch.long, device=x.device)
else:
offset = state.offset
offset = self.offset

if self.positional_embedding in {"sin", "sin_rope"}:
positions = torch.arange(T, device=x.device).view(1, -1, 1)
@@ -692,7 +700,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs):
x = layer(x, *args, **kwargs)

if state is not None:
state.offset.add_(T)
self.offset.add_(T)
return x
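Note: same pattern in StreamingTransformer: the running offset becomes a registered buffer bumped with an in-place add_, and _TransformerState is reduced to an empty placeholder. A tiny standalone illustration of a buffer-backed counter advancing across calls (not the repo's API):

# --- illustrative sketch, not part of the diff ---
import torch
from torch import nn

class OffsetCounter(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("offset", torch.zeros(1, dtype=torch.long))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        positions = torch.arange(x.shape[1], device=x.device) + self.offset
        self.offset.add_(x.shape[1])   # persists across calls, no Python-side state
        return positions

c = OffsetCounter()
print(c(torch.zeros(1, 3)))  # tensor([0, 1, 2])
print(c(torch.zeros(1, 3)))  # tensor([3, 4, 5])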

