# [ascend]feat: support kv int8 #103

**Open** · yao-fengchen wants to merge 7 commits into DeepLink-org:main from yao-fengchen:ascend_kv_int8
**Commits (7):**

- fd5db58 [ascend]feat: support kv int8 quant (yao-fengchen)
- d00aed3 update doc (yao-fengchen)
- 5eb0894 format code (yao-fengchen)
- 8bbec89 update code (yao-fengchen)
- 82a8a01 update params (yao-fengchen)
- b8fc1b3 test ascend_kv_int8 (yao-fengchen)
- c200c6e update docs (yao-fengchen)
**Changes from all commits:**
```diff
@@ -151,13 +151,24 @@ def prefill_attention(
     )
 
 
-@register_custom_op("dlinfer::fill_kv_cache", ["key_cache", "value_cache"])
+@register_custom_op(
+    "dlinfer::fill_kv_cache",
+    ["key_cache", "value_cache"],
+    default_value={
+        "k_scales_zeros": tuple(),
+        "v_scales_zeros": tuple(),
+        "quant_bits": 0,
+    },
+)
 def fill_kv_cache(
     key: Tensor,
     value: Tensor,
     key_cache: Tensor,
     value_cache: Tensor,
     kv_indices: Tensor,
+    k_scales_zeros: Sequence[Optional[Tensor]],
+    v_scales_zeros: Sequence[Optional[Tensor]],
+    quant_bits: int,
 ) -> Tuple[Tensor, Tensor]:
     """
     Fills the key-value cache with the provided key and value tensors.
```
```diff
@@ -168,6 +179,9 @@ def fill_kv_cache(
         key_cache (Tensor): The existing key cache tensor.
         value_cache (Tensor): The existing value cache tensor.
         kv_indices (Tensor): The indices specifying where to store the key and value in the cache.
+        k_scales_zeros (Sequence[Optional[Tensor]]): The scales and zeros used to quantize the key.
+        v_scales_zeros (Sequence[Optional[Tensor]]): The scales and zeros used to quantize the value.
+        quant_bits (int): The number of bits to which k/v are quantized.
 
     Returns:
         Tuple[Tensor, Tensor]:
```
```diff
@@ -180,6 +194,9 @@ def fill_kv_cache(
         key_cache,
         value_cache,
         kv_indices,
+        k_scales_zeros,
+        v_scales_zeros,
+        quant_bits,
     )
 
 
```
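For orientation, here is a minimal sketch of the quantized `fill_kv_cache` call these hunks enable. Everything in it is illustrative: the import path, tensor shapes, and per-head factor layout are assumptions, not part of this diff.

```python
import torch
# Assumed import path; the diff only shows the op's definition site.
from dlinfer.ops import fill_kv_cache

num_tokens, num_heads, head_dim = 4, 8, 128
key = torch.randn(num_tokens, num_heads, head_dim)
value = torch.randn(num_tokens, num_heads, head_dim)

# int8 caches; the block layout here is made up for illustration.
key_cache = torch.zeros(64, 16, num_heads, head_dim, dtype=torch.int8)
value_cache = torch.zeros_like(key_cache)
kv_indices = torch.arange(num_tokens, dtype=torch.int32)

# One (scales, zeros) pair each for key and value (assumed per-head shape).
k_scales_zeros = (torch.full((num_heads,), 0.05), torch.zeros(num_heads))
v_scales_zeros = (torch.full((num_heads,), 0.04), torch.zeros(num_heads))

# quant_bits=8 asks the op to store k/v as int8 using the given factors.
fill_kv_cache(
    key, value, key_cache, value_cache, kv_indices,
    k_scales_zeros, v_scales_zeros, quant_bits=8,
)
```

With the registered `default_value`, pre-existing callers that omit the three new arguments should keep working, with `quant_bits` falling back to 0 (no quantization).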
```diff
@@ -190,6 +207,9 @@ def fill_kv_cache(
         "softmax_scale": None,
         "alibi_slopes": None,
         "attn_output": None,
+        "kv_scales": None,
+        "kv_zeros": None,
+        "quant_bits": 0,
     },
 )
 def paged_decode_attention(
```
```diff
@@ -205,6 +225,9 @@ def paged_decode_attention(
     softmax_scale: Optional[float],
     alibi_slopes: Optional[Sequence[float]],
     attn_output: Optional[Tensor],
+    kv_scales: Optional[Tensor],
+    kv_zeros: Optional[Tensor],
+    quant_bits: Optional[int],
 ) -> Tensor:
     """
     Computes the multi-head attention over the query, key, and value tensors.
```

> **Review comment on `kv_zeros`:** Why is the type of `kv_zeros` here inconsistent with the types of `k_scales_zeros` and `v_scales_zeros` in `fill_kv_cache`?
>
> **Reply:** In `fill_kv_cache` that type is used to avoid slicing.
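The thread above points at a deliberate asymmetry between the two signatures. A rough sketch of the two layouts, with assumed per-head factor shapes:

```python
import torch

num_heads = 8
k_scales, k_zeros = torch.full((num_heads,), 0.05), torch.zeros(num_heads)
v_scales, v_zeros = torch.full((num_heads,), 0.04), torch.zeros(num_heads)

# fill_kv_cache: key and value factors travel as separate sequence entries,
# so the op can consume each tensor directly instead of slicing a packed one.
k_scales_zeros = (k_scales, k_zeros)
v_scales_zeros = (v_scales, v_zeros)

# paged_*_attention: one packed tensor per kind, e.g. stacked along a leading
# k/v axis (assumed layout; the Ascend kernel defines the real one).
kv_scales = torch.stack([k_scales, v_scales])
kv_zeros = torch.stack([k_zeros, v_zeros])
```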
```diff
@@ -224,6 +247,9 @@ def paged_decode_attention(
         softmax_scale (Optional[float]): The scale factor to apply to the attention logits before the softmax.
         alibi_slopes (Optional[Sequence[float]]): The slopes for the ALiBi attention bias, one for each head.
         attn_output (Optional[Tensor]): The computed attention output tensor.
+        kv_scales (Optional[Tensor]): The quantization scales for key and value.
+        kv_zeros (Optional[Tensor]): The quantization zero points for key and value.
+        quant_bits (Optional[int]): The number of bits to which k/v are quantized.
 
     Returns:
         Tensor: The computed attention output tensor, alias of attn_output.
```
```diff
@@ -241,6 +267,9 @@ def paged_decode_attention(
         softmax_scale,
         alibi_slopes,
         attn_output,
+        kv_scales,
+        kv_zeros,
+        quant_bits,
     )
 
 
```
```diff
@@ -251,6 +280,9 @@ def paged_decode_attention(
         "softmax_scale": None,
         "alibi_slopes": None,
         "attn_output": None,
+        "kv_scales": None,
+        "kv_zeros": None,
+        "quant_bits": 0,
     },
 )
 def paged_prefill_attention(
```
```diff
@@ -268,6 +300,9 @@ def paged_prefill_attention(
     softmax_scale: Optional[float],
     alibi_slopes: Optional[Sequence[float]],
     attn_output: Optional[Tensor],
+    kv_scales: Optional[Tensor],
+    kv_zeros: Optional[Tensor],
+    quant_bits: Optional[int],
 ) -> Tensor:
     """
     Computes the multi-head attention over the query, key, and value tensors.
```
```diff
@@ -289,6 +324,9 @@ def paged_prefill_attention(
         softmax_scale (Optional[float]): The scale factor to apply to the attention logits before the softmax.
         alibi_slopes (Optional[Sequence[float]]): The slopes for the ALiBi attention bias, one for each head.
         attn_output (Optional[Tensor]): The computed attention output tensor.
+        kv_scales (Optional[Tensor]): The quantization scales for key and value.
+        kv_zeros (Optional[Tensor]): The quantization zero points for key and value.
+        quant_bits (Optional[int]): The number of bits to which k/v are quantized.
 
     Returns:
         Tensor: The computed attention output tensor, alias of attn_output.
```
```diff
@@ -308,6 +346,9 @@ def paged_prefill_attention(
         softmax_scale,
         alibi_slopes,
         attn_output,
+        kv_scales,
+        kv_zeros,
+        quant_bits,
     )
 
 
```
**New documentation file (+52 lines):**
# KV Cache Quantization

Currently, on Huawei Atlas 800T A2 devices, operator feature limitations mean that only offline quantization is supported in operator (eager) mode.
## Prerequisites for KV Cache Quantization

- **Dependencies**
  ```shell
  torch==2.1.0
  torchvision==0.16.0
  torch-npu==2.1.0.post6
  ```
- **Tools**
  ```shell
  amct_pytorch==0.22.2 (Ascend-cann-amct_8.0.RC2)
  ```
## KV Cache Quantization Example
Run the following command in the current directory to generate the quantization-factor record file. Adjust model_path (for VL models, use the weights of their language model) and dataset_path in the example script to your setup, and modify quant_layers to match the model structure.
```shell
python3 ascend_scales_offsets.py
```
After inference succeeds, the quantization log file ./amct_log/amct_pytorch.log and an ./outputs folder are generated in the current directory. The folder contains:
- **config.json**: the quantization configuration file, describing how each layer of the model is quantized.
- **record.txt**: the quantization-factor record file (the sketch below shows the usual int8 mapping such factors drive).
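For reference, a sketch of the typical asymmetric int8 mapping that scales and zero points like these drive; the exact formula amct applies is an assumption here, not something this document specifies:

```python
import torch

def quantize_kv(x: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
    # q = clamp(round(x / scale) + zero, -128, 127), stored as int8.
    # scale and zero must broadcast against x.
    return torch.clamp(torch.round(x / scale) + zero, -128, 127).to(torch.int8)

def dequantize_kv(q: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
    # Approximate inverse used when the cache is read back: x ≈ (q - zero) * scale.
    return (q.to(torch.float32) - zero) * scale

# Round trip with scalar factors for illustration.
x = torch.randn(4, 8, 128)
q = quantize_kv(x, torch.tensor(0.05), torch.tensor(0.0))
x_hat = dequantize_kv(q, torch.tensor(0.05), torch.tensor(0.0))
```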
When running lmdeploy, point the environment variable ASCEND_QUANT_RECORD_FILE at the quantization-factor record file and pass quant_policy=8; inference then uses the recorded factors. Example code:
```python
import lmdeploy
from lmdeploy import PytorchEngineConfig

if __name__ == "__main__":
    pipe = lmdeploy.pipeline(
        "/path_to_model",
        backend_config=PytorchEngineConfig(
            tp=1,
            cache_max_entry_count=0.4,
            device_type="ascend",
            eager_mode=True,
            quant_policy=8,
        ),
    )
    question = ["Shanghai is", "Please introduce China", "How are you?"]
    response = pipe(question, request_output_len=256, do_preprocess=False)
    for idx, r in enumerate(response):
        print(f"Q: {question[idx]}")
        print(f"A: {r.text}")
        print()
```
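Note that ASCEND_QUANT_RECORD_FILE must be set in the environment before the pipeline is created; a minimal sketch, assuming the ./outputs folder produced by the step above:

```shell
export ASCEND_QUANT_RECORD_FILE=./outputs/record.txt
```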
> **Review comment:** Could default values be added: `"k_scales_zeros": tuple()` and `"v_scales_zeros": tuple()`?
>
> **Reply:** Added.