Skip to content

Commit

Permalink
LLM: support int4 fp16 chatglm2-6b 8k input. (#10648)
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalapotter authored Apr 7, 2024
1 parent ab87b6a commit 1a9b820
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions python/llm/src/ipex_llm/transformers/models/chatglm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,10 +512,23 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
query_layer = query_layer.permute(1, 2, 0, 3)
L, S = query_layer.shape[2], key_layer.shape[2]
if attention_mask is None and L == S:
context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
key_layer,
value_layer,
is_causal=True).to(key_layer.dtype)
# split tensor for memory block limitation
# support fp16 and set input length threshold at 5000 for now
if query_layer.dtype == torch.float16 and L >= 5000:
# split first dim 32 -> 8
query_sp = torch.split(query_layer.to(key_layer.dtype), 8, dim=1)
key_sp = torch.split(key_layer, 8, dim=1)
value_sp = torch.split(value_layer, 8, dim=1)
results = []
for q, k, v in zip(query_sp, key_sp, value_sp):
result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
results.append(result)
context_layer = torch.cat(results, dim=1)
else:
context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
key_layer,
value_layer,
is_causal=True).to(key_layer.dtype)
else:
if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
query_layer.shape[-1], query_layer):
Expand Down

0 comments on commit 1a9b820

Please sign in to comment.