[ascend]feat: support kv int8 #103

Open · wants to merge 7 commits into base: main
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -12,7 +12,7 @@ on:
env:
CI_PATH: '/data2/wugeshui/GitHub/${{ github.repository }}/${GITHUB_RUN_NUMBER}'
LMDEPLOY_PATH: '/data2/wugeshui/GitHub/lmdeploy'
LMDEPLOY_COMMIT_OR_BRANCH: 'main'
LMDEPLOY_COMMIT_OR_BRANCH: 'ascend_kv_int8'
REPORT_DIR: /data2/wugeshui/GitHub/ci_log/test_reports

concurrency:
18 changes: 17 additions & 1 deletion dlinfer/graph/dicp/vendor/AtbGraph/conversion.py
@@ -126,7 +126,7 @@ def apply_rotary_pos_emb(self, q, k, cos, sin, q_out, k_out):
return out

@register_conversion("torch.ops.dlinfer.fill_kv_cache.default")
def fill_kv_cache(self, key, value, key_cache, value_cache, kv_indices):
def fill_kv_cache(
self,
key,
value,
key_cache,
value_cache,
kv_indices,
k_scales_zeros,
v_scales_zeros,
quant_bits,
):
key_cache_shape = key_cache.node.meta["val"].shape
key_shape = key.node.meta["val"].shape
key_cache_reshaped = self.get_proxy(
@@ -171,6 +181,9 @@ def paged_attention_decode(
softmax_scale,
alibi_slopes,
attn_output,
kv_scales,
kv_zeros,
quant_bits,
):
q_head_num = num_q_heads
kv_head_num = num_kv_heads
@@ -370,6 +383,9 @@ def prefill_attention(
block_size,
mask,
is_unpaged_prefill,
kv_scales,
kv_zeros,
quant_bits,
):
# k_cache = self.get_proxy(atb_op.View, (k_cache, [-1, block_size, num_kv_heads, kv_head_size]))
# v_cache = self.get_proxy(atb_op.View, (v_cache, [-1, block_size, num_kv_heads, kv_head_size]))
43 changes: 42 additions & 1 deletion dlinfer/ops/llm.py
@@ -151,13 +151,24 @@ def prefill_attention(
)


@register_custom_op("dlinfer::fill_kv_cache", ["key_cache", "value_cache"])
@register_custom_op(
"dlinfer::fill_kv_cache",
["key_cache", "value_cache"],
default_value={
"k_scales_zeros": tuple(),
"v_scales_zeros": tuple(),
"quant_bits": 0,
Collaborator:

Could default values be added, e.g.:
"k_scales_zeros": tuple()
"v_scales_zeros": tuple()

Contributor Author:

Added.
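For illustration, a minimal sketch of what these defaults buy existing call sites. The decorator below is only a toy stand-in for dlinfer's register_custom_op (whose real implementation is not part of this diff); it merely mimics how default_value lets callers that predate KV int8 omit the new arguments:

```python
import inspect
from functools import wraps

def register_custom_op_stub(name, mutated_args, default_value=None):
    # Toy stand-in, NOT dlinfer's real register_custom_op: it only fills in
    # missing arguments from default_value before calling the function.
    defaults = default_value or {}
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            bound = inspect.signature(fn).bind_partial(*args, **kwargs)
            for key, value in defaults.items():
                bound.arguments.setdefault(key, value)
            return fn(*bound.args, **bound.kwargs)
        return wrapper
    return decorator

@register_custom_op_stub(
    "dlinfer::fill_kv_cache",
    ["key_cache", "value_cache"],
    default_value={"k_scales_zeros": tuple(), "v_scales_zeros": tuple(), "quant_bits": 0},
)
def fill_kv_cache(key, value, key_cache, value_cache, kv_indices,
                  k_scales_zeros, v_scales_zeros, quant_bits):
    return k_scales_zeros, v_scales_zeros, quant_bits

# Pre-existing call sites can keep passing only the original five arguments.
print(fill_kv_cache("k", "v", "k_cache", "v_cache", "idx"))  # ((), (), 0)
```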

},
)
def fill_kv_cache(
key: Tensor,
value: Tensor,
key_cache: Tensor,
value_cache: Tensor,
kv_indices: Tensor,
k_scales_zeros: Sequence[Optional[Tensor]],
v_scales_zeros: Sequence[Optional[Tensor]],
quant_bits: int,
) -> Tuple[Tensor, Tensor]:
"""
Fills the key-value cache with the provided key and value tensors.
@@ -168,6 +179,9 @@ def fill_kv_cache(
key_cache (Tensor): The existing key cache tensor.
value_cache (Tensor): The existing value cache tensor.
kv_indices (Tensor): The indices specifying where to store the key and value in the cache.
k_scales_zeros (Sequence[Optional[Tensor]]): The scales and zeros used to quantize the key.
v_scales_zeros (Sequence[Optional[Tensor]]): The scales and zeros used to quantize the value.
quant_bits (int): The number of bits to which key/value are quantized.

Returns:
Tuple[Tensor, Tensor]:
@@ -180,6 +194,9 @@ def fill_kv_cache(
key_cache,
value_cache,
kv_indices,
k_scales_zeros,
v_scales_zeros,
quant_bits,
)


@@ -190,6 +207,9 @@ def fill_kv_cache(
"softmax_scale": None,
"alibi_slopes": None,
"attn_output": None,
"kv_scales": None,
"kv_zeros": None,
"quant_bits": 0,
},
)
def paged_decode_attention(
@@ -205,6 +225,9 @@ def paged_decode_attention(
softmax_scale: Optional[float],
alibi_slopes: Optional[Sequence[float]],
attn_output: Optional[Tensor],
kv_scales: Optional[Tensor],
kv_zeros: Optional[Tensor],
Collaborator:

Why is the type of kv_zeros here different from that of k_scales_zeros and v_scales_zeros in fill_kv_cache?

Contributor Author:

The pair form in fill_kv_cache is there to avoid slicing.
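To illustrate the answer (shapes and packing below are assumptions for the sketch, not taken from the PR): fill_kv_cache receives each parameter set as a (scale, zero) pair so the backend can index k_scales_zeros[0] / k_scales_zeros[1] directly, while the paged attention ops take single kv_scales / kv_zeros tensors that are forwarded to the fused kernel's antiquant_scale / antiquant_offset arguments:

```python
import torch

num_kv_heads = 8  # illustrative

# Hypothetical per-head quantization parameters.
k_scale, k_zero = torch.rand(num_kv_heads, 1), torch.zeros(num_kv_heads, 1)
v_scale, v_zero = torch.rand(num_kv_heads, 1), torch.zeros(num_kv_heads, 1)

# fill_kv_cache convention: a (scale, zero) pair per tensor, so the vendor code
# reads k_scales_zeros[0] and k_scales_zeros[1] without slicing a packed tensor.
k_scales_zeros = (k_scale, k_zero)
v_scales_zeros = (v_scale, v_zero)

# paged_*_attention convention: single tensors covering key and value, passed
# straight through as antiquant_scale / antiquant_offset (stacking layout assumed).
kv_scales = torch.stack([k_scale, v_scale])
kv_zeros = torch.stack([k_zero, v_zero])
print(kv_scales.shape, k_scales_zeros[0].shape)  # torch.Size([2, 8, 1]) torch.Size([8, 1])
```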

quant_bits: Optional[int],
) -> Tensor:
"""
Computes the multi-head attention over the query, key, and value tensors.
@@ -224,6 +247,9 @@ def paged_decode_attention(
softmax_scale (Optional[float]): The scale factor to apply to the attention logits before the softmax.
alibi_slopes (Optional[Sequence[float]]): The slopes for the ALiBi attention bias, one for each head.
attn_output (Optional[Tensor]): The computed attention output tensor.
kv_scales (Optional[Tensor]): The quantization scales for key and value.
kv_zeros (Optional[Tensor]): The quantization zero points for key and value.
quant_bits (Optional[int]): The number of bits to which key/value are quantized.

Returns:
Tensor: The computed attention output tensor, alias of attn_output.
@@ -241,6 +267,9 @@ def paged_decode_attention(
softmax_scale,
alibi_slopes,
attn_output,
kv_scales,
kv_zeros,
quant_bits,
)


@@ -251,6 +280,9 @@ def paged_decode_attention(
"softmax_scale": None,
"alibi_slopes": None,
"attn_output": None,
"kv_scales": None,
"kv_zeros": None,
"quant_bits": 0,
},
)
def paged_prefill_attention(
@@ -268,6 +300,9 @@ def paged_prefill_attention(
softmax_scale: Optional[float],
alibi_slopes: Optional[Sequence[float]],
attn_output: Optional[Tensor],
kv_scales: Optional[Tensor],
kv_zeros: Optional[Tensor],
quant_bits: Optional[int],
) -> Tensor:
"""
Computes the multi-head attention over the query, key, and value tensors.
@@ -289,6 +324,9 @@ def paged_prefill_attention(
softmax_scale (Optional[float]): The scale factor to apply to the attention logits before the softmax.
alibi_slopes (Optional[Sequence[float]]): The slopes for the ALiBi attention bias, one for each head.
attn_output (Optional[Tensor]): The computed attention output tensor.
kv_scales (Optional[Tensor]): The quantization scales for key and value.
kv_zeros (Optional[Tensor]): The quantization zero points for key and value.
quant_bits (Optional[int]): The number of bits to which key/value are quantized.

Returns:
Tensor: The computed attention output tensor, alias of attn_output.
@@ -308,6 +346,9 @@ def paged_prefill_attention(
softmax_scale,
alibi_slopes,
attn_output,
kv_scales,
kv_zeros,
quant_bits,
)


30 changes: 25 additions & 5 deletions dlinfer/vendor/ascend/torch_npu_ops.py
@@ -122,8 +122,11 @@ def fill_kv_cache(
key_cache: Tensor,
value_cache: Tensor,
kv_indices: Tensor,
k_scales_zeros: Sequence[Optional[Tensor]],
v_scales_zeros: Sequence[Optional[Tensor]],
quant_bits: int,
) -> Tuple[Tensor, Tensor]:
head, dim = key.shape[1:]
_, head, dim = key.shape
block_num, block_size = key_cache.shape[:2]
block_total = block_num * block_size

@@ -132,6 +135,17 @@
value = value.contiguous()
kv_indices = kv_indices.view(-1, 1)

if quant_bits == 8:

def quant_int8(x, x_scale, x_offset):
quantized = (
((x / x_scale) - x_offset).round().clamp(-128, 127).to(torch.int8)
)
return quantized

key = quant_int8(key, k_scales_zeros[0], k_scales_zeros[1])
value = quant_int8(value, v_scales_zeros[0], v_scales_zeros[1])

key_cache_reshaped = key_cache.view(block_total, head, dim)
value_cache_reshaped = value_cache.view(block_total, head, dim)
torch.ops.npu.npu_scatter_nd_update_(key_cache_reshaped, kv_indices, key)
@@ -167,6 +181,9 @@ def paged_decode_attention(
softmax_scale: Optional[float],
alibi_slopes: Optional[Sequence[float]],
attn_output: Optional[Tensor],
kv_scales: Optional[Tensor],
kv_zeros: Optional[Tensor],
quant_bits: Optional[int],
) -> Tensor:
if alibi_slopes is not None:
raise RuntimeError(
@@ -188,8 +205,8 @@
padding_mask=None,
atten_mask=None,
actual_seq_lengths=kv_seq_len.tolist(),
antiquant_scale=None,
antiquant_offset=None,
antiquant_scale=kv_scales,
antiquant_offset=kv_zeros,
block_table=block_table,
dequant_scale1=None,
quant_scale1=None,
@@ -222,6 +239,9 @@ def paged_prefill_attention(
softmax_scale: Optional[float],
alibi_slopes: Optional[Sequence[float]],
attn_output: Optional[Tensor],
kv_scales: Optional[Tensor],
kv_zeros: Optional[Tensor],
quant_bits: Optional[int],
) -> Tensor:
if alibi_slopes is not None:
raise RuntimeError(
@@ -245,8 +265,8 @@
padding_mask=None,
atten_mask=attn_mask[0],
actual_seq_lengths=kv_seq_len_list,
antiquant_scale=None,
antiquant_offset=None,
antiquant_scale=kv_scales,
antiquant_offset=kv_zeros,
block_table=block_table,
dequant_scale1=None,
quant_scale1=None,
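A quick aside on the quantization scheme introduced above: the sketch below mirrors the quant_int8 helper added to fill_kv_cache and pairs it with the presumed inverse (what the fused attention kernel is assumed to do with antiquant_scale / antiquant_offset). It runs on CPU with plain PyTorch:

```python
import torch

def quant_int8(x, x_scale, x_offset):
    # Same formula as the PR: q = round(x / scale - offset), clamped to int8.
    return ((x / x_scale) - x_offset).round().clamp(-128, 127).to(torch.int8)

def dequant_int8(q, x_scale, x_offset):
    # Assumed inverse applied at attention time: x ≈ (q + offset) * scale.
    return (q.to(torch.float32) + x_offset) * x_scale

x = torch.randn(4, 8, 16)                # (tokens, kv_heads, head_dim), illustrative
scale = x.abs().amax() / 127.0           # toy per-tensor scale
offset = torch.zeros(())                 # zero offset -> symmetric quantization
q = quant_int8(x, scale, offset)
error = (dequant_int8(q, scale, offset) - x).abs().max()
print(q.dtype, float(error))             # torch.int8 and a small reconstruction error
```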
9 changes: 9 additions & 0 deletions dlinfer/vendor/maca/maca_ops.py
@@ -182,6 +182,9 @@ def fill_kv_cache(
key_cache: Tensor,
value_cache: Tensor,
kv_indices: Tensor,
k_scales_zeros: Sequence[Optional[Tensor]],
v_scales_zeros: Sequence[Optional[Tensor]],
quant_bits: int,
) -> Tuple[Tensor, Tensor]:
kv_indices = kv_indices.squeeze(-1)
maca_ext_ops.reshape_and_cache_new(
@@ -204,6 +207,9 @@ def paged_decode_attention(
softmax_scale: Optional[float],
alibi_slopes: Optional[Sequence[float]],
attn_output: Optional[Tensor],
kv_scales: Optional[Tensor],
kv_zeros: Optional[Tensor],
quant_bits: Optional[int],
) -> Tensor:
if alibi_slopes is not None:
raise RuntimeError("paged_decode_attention does not support alibi_slopes yet")
@@ -269,6 +275,9 @@ def paged_prefill_attention(
softmax_scale: Optional[float],
alibi_slopes: Optional[Sequence[float]],
attn_output: Optional[Tensor],
kv_scales: Optional[Tensor],
kv_zeros: Optional[Tensor],
quant_bits: Optional[int],
) -> Tensor:
dim = query.size(-1)
batch_size = block_table.size(0)
52 changes: 52 additions & 0 deletions docs/quant/ascend_kv_quant.md
@@ -0,0 +1,52 @@

# KV Cache Quantization

Currently, on Huawei Atlas 800T A2 devices, only offline quantization is supported in eager (operator) mode, due to limitations of the underlying operators.

## Prerequisites for KV Cache Quantization

- **Dependencies**

```shell
torch==2.1.0
torchvision==0.16.0
torch-npu==2.1.0.post6
```

- **Tools**

```shell
amct_pytorch==0.22.2 (Ascend-cann-amct_8.0.RC2)
```

## KV Cache Quantization Example

Run the following command in the current directory to produce the quantization-factor record file. Adjust model_path in the example script to your model (for VL models, use the weights of their language model) and dataset_path to your dataset, and modify quant_layers according to the model structure.

```shell
python3 ascend_scales_offsets.py
```

After inference succeeds, a quantization log file ./amct_log/amct_pytorch.log and an ./outputs folder are generated in the current directory. The folder contains:

- **config.json**: the quantization configuration file, describing how each layer of the model is quantized.
- **record.txt**: the quantization-factor record file.

When using lmdeploy, set the environment variable ASCEND_QUANT_RECORD_FILE to the path of the quantization-factor record file and pass quant_policy=8 to run inference with the recorded quantization factors.
Example code:

```python
import lmdeploy
from lmdeploy import PytorchEngineConfig
if __name__ == "__main__":
pipe = lmdeploy.pipeline("/path_to_model",
backend_config = PytorchEngineConfig(tp=1,
cache_max_entry_count=0.4, device_type="ascend",
eager_mode=True, quant_policy=8))
question = ["Shanghai is", "Please introduce China", "How are you?"]
response = pipe(question, request_output_len=256, do_preprocess=False)
for idx, r in enumerate(response):
print(f"Q: {question[idx]}")
print(f"A: {r.text}")
print()
```
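For reference, one way to point the pipeline at the record file from Python before constructing it (the path is illustrative; the variable can equally be exported in the shell before launching):

```python
import os

# Illustrative path: record.txt is produced in ./outputs by ascend_scales_offsets.py.
os.environ["ASCEND_QUANT_RECORD_FILE"] = "./outputs/record.txt"
```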