Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
fix generate cache of glm models.
Browse files Browse the repository at this point in the history
  • Loading branch information
mikecovlee committed Jul 22, 2024
1 parent 23f112f commit 0e64974
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 10 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ For users with NVIDIA Ampere or newer GPU architectures, the `--tf32` option can
+ Quantization with Qwen2 have no effect (same with transformers).
+ Applying quantization with DoRA will result in higher memory and computation cost (same with PEFT).
+ Sliding window attention with generate cache may product abnormal output.
+ ChatGLM models with generate cache may product abnormal output.

## Installation

Expand Down
7 changes: 2 additions & 5 deletions mlora/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,17 +425,14 @@ def generate(
)
)

if cache_implementation is None:
if use_cache and cache_implementation is None:
cache_implementation = model.model_.cache_implementation()
if use_cache and cache_implementation is None:
if cache_implementation is None:
logging.warn(
"Cache disabled by model, use cache_implementation to force enable."
)
use_cache = False

if use_cache is None and cache_implementation is not None:
use_cache = True

packed_outputs: Dict[str, List] = {}

while True:
Expand Down
5 changes: 1 addition & 4 deletions mlora/models/modeling_chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def forward(

# apply relative positional encoding (rotary embedding)
if self.rotary_pos_emb is not None:
rotary_pos_emb = self.rotary_pos_emb[None, : hidden_states.shape[1]]
rotary_pos_emb = self.rotary_pos_emb[None, cache_position]
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

Expand Down Expand Up @@ -748,9 +748,6 @@ def causal_mask(
) -> torch.Tensor:
return self.get_masks(input_tensor, past_key_values, attention_mask)

def cache_implementation(self) -> str:
return None

def model_config(self) -> GLMConfig:
return self.config_

Expand Down

0 comments on commit 0e64974

Please sign in to comment.