From fd5db582b61414c9195941ff06d235d819f93d7c Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Fri, 8 Nov 2024 11:03:48 +0000 Subject: [PATCH 1/7] [ascend]feat: support kv int8 quant --- dlinfer/ops/llm.py | 36 ++++++++++++++++++++++++-- dlinfer/vendor/ascend/torch_npu_ops.py | 32 ++++++++++++++++++----- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/dlinfer/ops/llm.py b/dlinfer/ops/llm.py index 42955c1..10ea569 100644 --- a/dlinfer/ops/llm.py +++ b/dlinfer/ops/llm.py @@ -151,14 +151,25 @@ def prefill_attention( ) -@register_custom_op("dlinfer::fill_kv_cache", ["key_cache", "value_cache"]) +@register_custom_op( + "dlinfer::fill_kv_cache", + ["key_cache", "value_cache"], + default_value={ + "k_scales_zeros": None, + "v_scales_zeros": None, + "quant_bits": 0, + }, +) def fill_kv_cache( key: Tensor, value: Tensor, key_cache: Tensor, value_cache: Tensor, kv_indices: Tensor, -) -> Tuple[Tensor, Tensor]: + k_scales_zeros: Sequence[Optional[Tensor]], + v_scales_zeros: Sequence[Optional[Tensor]], + quant_bits: int, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ Fills the key-value cache with the provided key and value tensors. @@ -180,6 +191,9 @@ def fill_kv_cache( key_cache, value_cache, kv_indices, + k_scales_zeros, + v_scales_zeros, + quant_bits, ) @@ -190,6 +204,9 @@ def fill_kv_cache( "softmax_scale": None, "alibi_slopes": None, "attn_output": None, + "kv_scales": None, + "kv_zeros": None, + "quant_bits": 0, }, ) def paged_decode_attention( @@ -205,6 +222,9 @@ def paged_decode_attention( softmax_scale: Optional[float], alibi_slopes: Optional[Sequence[float]], attn_output: Optional[Tensor], + kv_scales: Tensor, + kv_zeros: Tensor, + quant_bits: int, ) -> Tensor: """ Computes the multi-head attention over the query, key, and value tensors. @@ -241,6 +261,9 @@ def paged_decode_attention( softmax_scale, alibi_slopes, attn_output, + kv_scales, + kv_zeros, + quant_bits, ) @@ -251,6 +274,9 @@ def paged_decode_attention( "softmax_scale": None, "alibi_slopes": None, "attn_output": None, + "kv_scales": None, + "kv_zeros": None, + "quant_bits": 0, }, ) def paged_prefill_attention( @@ -268,6 +294,9 @@ def paged_prefill_attention( softmax_scale: Optional[float], alibi_slopes: Optional[Sequence[float]], attn_output: Optional[Tensor], + kv_scales: Tensor, + kv_zeros: Tensor, + quant_bits: int, ) -> Tensor: """ Computes the multi-head attention over the query, key, and value tensors. 
@@ -308,6 +337,9 @@ def paged_prefill_attention( softmax_scale, alibi_slopes, attn_output, + kv_scales, + kv_zeros, + quant_bits, ) diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py index 8332608..36135ff 100644 --- a/dlinfer/vendor/ascend/torch_npu_ops.py +++ b/dlinfer/vendor/ascend/torch_npu_ops.py @@ -122,8 +122,11 @@ def fill_kv_cache( key_cache: Tensor, value_cache: Tensor, kv_indices: Tensor, -) -> Tuple[Tensor, Tensor]: - head, dim = key.shape[1:] + k_scales_zeros: Sequence[Optional[Tensor]], + v_scales_zeros: Sequence[Optional[Tensor]], + quant_bits: int, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + _, head, dim = key.shape block_num, block_size = key_cache.shape[:2] block_total = block_num * block_size @@ -132,6 +135,17 @@ def fill_kv_cache( value = value.contiguous() kv_indices = kv_indices.view(-1, 1) + if quant_bits == 8: + + def quant_int8(x, x_scale, x_offset): + quantized = ( + ((x / x_scale) - x_offset).round().clamp(-128, 127).to(torch.int8) + ) + return quantized + + key = quant_int8(key, k_scales_zeros[0], k_scales_zeros[1]) + value = quant_int8(value, v_scales_zeros[0], v_scales_zeros[1]) + key_cache_reshaped = key_cache.view(block_total, head, dim) value_cache_reshaped = value_cache.view(block_total, head, dim) torch.ops.npu.npu_scatter_nd_update_(key_cache_reshaped, kv_indices, key) @@ -167,6 +181,9 @@ def paged_decode_attention( softmax_scale: Optional[float], alibi_slopes: Optional[Sequence[float]], attn_output: Optional[Tensor], + kv_scales: Tensor, + kv_zeros: Tensor, + quant_bits: int, ) -> Tensor: if alibi_slopes is not None: raise RuntimeError( @@ -188,8 +205,8 @@ def paged_decode_attention( padding_mask=None, atten_mask=None, actual_seq_lengths=kv_seq_len.tolist(), - antiquant_scale=None, - antiquant_offset=None, + antiquant_scale=kv_scales, + antiquant_offset=kv_zeros, block_table=block_table, dequant_scale1=None, quant_scale1=None, @@ -222,6 +239,9 @@ def paged_prefill_attention( softmax_scale: Optional[float], alibi_slopes: Optional[Sequence[float]], attn_output: Optional[Tensor], + kv_scales: Tensor, + kv_zeros: Tensor, + quant_bits: int, ) -> Tensor: if alibi_slopes is not None: raise RuntimeError( @@ -245,8 +265,8 @@ def paged_prefill_attention( padding_mask=None, atten_mask=attn_mask[0], actual_seq_lengths=kv_seq_len_list, - antiquant_scale=None, - antiquant_offset=None, + antiquant_scale=kv_scales, + antiquant_offset=kv_zeros, block_table=block_table, dequant_scale1=None, quant_scale1=None, From d00aed35664a284ed7cc00db364f1375b55054ab Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Mon, 11 Nov 2024 06:57:06 +0000 Subject: [PATCH 2/7] update doc --- .../lmdeploy_ext/quants/ascend_kv.md | 39 +++++ .../lmdeploy_ext/quants/ascend_kv.py | 144 ++++++++++++++++++ 2 files changed, 183 insertions(+) create mode 100644 dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md create mode 100644 dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py diff --git a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md b/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md new file mode 100644 index 0000000..d918f7d --- /dev/null +++ b/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md @@ -0,0 +1,39 @@ +# KV Cache量化 +目前在华为Atlas 800T A2设备,由于算子功能限制,在算子模式下,仅支持离线量化。 +## KV Cache量化前提 +- 依赖 +``` +torch==2.1.0 +torchvision==0.16.0 +torch-npu==2.1.0.post6 +``` +- 工具 +``` +amct_pytorch==0.22.2(Ascend-cann-amct_8.0.RC2) +``` +## KV Cache量化示例 +在当前目录执行如下命令,得到量化因子记录文件,用户根据实际情况修改示例程序中的model_path和dataset_path,并根据模型结构修改quant_layers。 +``` +python3 
ascend_kv.py +``` +推理成功后,在当前目录会生成量化日志文件./amct_log/amct_pytorch.log和./outputs文件夹,该文件夹内包含以下内容: + +- config.json:量化配置文件,描述了如何对模型中的每一层进行量化。 +- record.txt:量化因子记录文件。 + +用户在使用lmdeploy时,通过环境变量ASCEND_QUANT_RECORD_FILE指定量化因子路径,并通过参数quant_policy=8,即可使用量化因子记录文件完成推理。 +示例代码如下: +``` +import lmdeploy +from lmdeploy import PytorchEngineConfig +if __name__ == "__main__": + pipe = lmdeploy.pipeline("/path_to_model", + backend_config = PytorchEngineConfig(tp=1, + cache_max_entry_count=0.4, device_type="ascend", eager_mode=True, quant_policy=8)) + question = ["Shanghai is", "Please introduce China", "How are you?"] + response = pipe(question, request_output_len=256, do_preprocess=False) + for idx, r in enumerate(response): + print(f"Q: {question[idx]}") + print(f"A: {r.text}") + print() +``` diff --git a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py b/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py new file mode 100644 index 0000000..96a67e8 --- /dev/null +++ b/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py @@ -0,0 +1,144 @@ +import os +import time + +import tqdm +import torch +import torch.nn as nn +import dlinfer +import transformers +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +from accelerate import infer_auto_device_map, dispatch_model +from accelerate.utils.modeling import get_balanced_memory +from datasets import load_dataset + +import amct_pytorch as amct + + +def get_llama2(model_path, seqlen=2048): + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16) + + model.seqlen = seqlen + return model + +def build_model_and_enc(model, model_path, gpu_num): + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + if "mpt" in config.__class__.__name__.lower(): + enc = AutoTokenizer.from_pretrained( + config.tokenizer_name, trust_remote_code=True + ) + else: + enc = AutoTokenizer.from_pretrained( + model_path, use_fast=False, trust_remote_code=True + ) + + # Move the model to GPU (as much as possible) for LM evaluation + # max_memory = ['0:16GiB', '1:16GiB','2:16GiB', 'cpu:30GiB'], '0' means the first GPU that you specify. + # I don't recommend use 16GiB, we need to reserve some space for other tensors during calculation + # please see the recommand memeory allocation in the Word file + # Adjust the max_size accroding to the real situation + # a clever way: + max_memory = [] + for i in range(gpu_num): + max_memory.append(f'{i}:12GiB') + max_memory.append('cpu:80GiB') + print('Max_memory allocation: \n', max_memory) + + max_memory = [v.split(":") for v in (max_memory or [])] + max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory} + kwargs = { + "max_memory": get_balanced_memory( + model, max_memory if len(max_memory) > 0 else None + ) + } + model.tie_weights() + device_map = infer_auto_device_map( + model, + # TODO: can we remove this? 
+ no_split_module_classes=[ + "OPTDecoderLayer", + "LlamaDecoderLayer", + "BloomBlock", + "MPTBlock", + "DecoderLayer", + ], + **kwargs, + ) + model = dispatch_model(model, device_map=device_map, + offload_dir=os.path.join(model_path, 'offload_dir')) + + return model, enc + + +def get_loaders(dataset_path: str, enc, seqlen): + print('Loading dataset c4/realnewslike') + testenc = load_dataset( + 'json', + data_files={'validation':dataset_path}, + split = 'validation' + ) + import pdb;pdb.set_trace() + testenc = enc(' '.join(testenc[:1100]['text']), return_tensors='pt') + testenc = testenc.input_ids[:, :(256 * seqlen)] + + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + testenc = TokenizerWrapper(testenc) + + return testenc + + +def main(): + # Load model + model_path = '/data2/share_data/internlm_model_data/internlm2_5-7b-chat' + model = get_llama2(model_path) + model = model.eval() + gpus = os.getenv('CUDA_VISIBLE_DEVICES') + if gpus == '' or gpus is None: + gpu_num = 0 + else: + gpu_num = len(gpus.split(',')) + model, enc = build_model_and_enc(model, model_path, gpu_num) + model.seqlen = 2048 + + # Load dataset + dataset_path = './c4/c4-train.00000-of-00512.json' + testenc = get_loaders(dataset_path=dataset_path, + enc=enc, + seqlen=model.seqlen) + + import pdb;pdb.set_trace() + testenc = testenc.input_ids.to(model.device) + + config_file = './outputs/config.json' + amct.create_quant_cali_config(config_file=config_file, + model=model, + quant_layers={'kv_cache_quant_laye': + [f'model.layers.{i}.attention.wqkv' for i in range(32)]}, + config_defination=None) + + record_file = './outputs/record.txt' + quant_cali_model = amct.create_quant_cali_model(config_file=config_file, + record_file=record_file, + model=model).npu() + + # Do inference to get quantize factors + batch_num = 3 + test_start_time = time.time() + for i in tqdm.tqdm(range(batch_num), desc="getting quantize factors..."): + batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(model.device) + with torch.no_grad(): + quant_cali_model(batch) + test_end_time = time.time() + total_time = test_end_time - test_start_time + print('Get quantize factors taken: ', total_time // 60, 'min ', total_time%60, 's' ) + + +if __name__ == '__main__': + main() From 5eb089456d926e2a12860c4cb4e70eef00408812 Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Mon, 11 Nov 2024 08:58:42 +0000 Subject: [PATCH 3/7] format code --- .../lmdeploy_ext/quants/ascend_kv.md | 31 ++++-- .../lmdeploy_ext/quants/ascend_kv.py | 96 +++++++++++-------- 2 files changed, 78 insertions(+), 49 deletions(-) diff --git a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md b/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md index d918f7d..b670852 100644 --- a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md +++ b/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md @@ -1,35 +1,48 @@ + # KV Cache量化 + 目前在华为Atlas 800T A2设备,由于算子功能限制,在算子模式下,仅支持离线量化。 + ## KV Cache量化前提 -- 依赖 -``` + +- **依赖** + +```shell torch==2.1.0 torchvision==0.16.0 torch-npu==2.1.0.post6 ``` -- 工具 -``` + +- **工具** + +```shell amct_pytorch==0.22.2(Ascend-cann-amct_8.0.RC2) ``` + ## KV Cache量化示例 + 在当前目录执行如下命令,得到量化因子记录文件,用户根据实际情况修改示例程序中的model_path和dataset_path,并根据模型结构修改quant_layers。 -``` + +```shell python3 ascend_kv.py ``` + 推理成功后,在当前目录会生成量化日志文件./amct_log/amct_pytorch.log和./outputs文件夹,该文件夹内包含以下内容: -- config.json:量化配置文件,描述了如何对模型中的每一层进行量化。 -- record.txt:量化因子记录文件。 +- **config.json**:量化配置文件,描述了如何对模型中的每一层进行量化。 +- **record.txt**:量化因子记录文件。 
用户在使用lmdeploy时,通过环境变量ASCEND_QUANT_RECORD_FILE指定量化因子路径,并通过参数quant_policy=8,即可使用量化因子记录文件完成推理。 示例代码如下: -``` + +```python import lmdeploy from lmdeploy import PytorchEngineConfig if __name__ == "__main__": pipe = lmdeploy.pipeline("/path_to_model", backend_config = PytorchEngineConfig(tp=1, - cache_max_entry_count=0.4, device_type="ascend", eager_mode=True, quant_policy=8)) + cache_max_entry_count=0.4, device_type="ascend", + eager_mode=True, quant_policy=8)) question = ["Shanghai is", "Please introduce China", "How are you?"] response = pipe(question, request_output_len=256, do_preprocess=False) for idx, r in enumerate(response): diff --git a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py b/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py index 96a67e8..50794ea 100644 --- a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py +++ b/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py @@ -21,11 +21,14 @@ def skip(*args, **kwargs): torch.nn.init.kaiming_uniform_ = skip torch.nn.init.uniform_ = skip torch.nn.init.normal_ = skip - model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16) + model = AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, torch_dtype=torch.float16 + ) model.seqlen = seqlen return model + def build_model_and_enc(model, model_path, gpu_num): config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) if "mpt" in config.__class__.__name__.lower(): @@ -45,9 +48,9 @@ def build_model_and_enc(model, model_path, gpu_num): # a clever way: max_memory = [] for i in range(gpu_num): - max_memory.append(f'{i}:12GiB') - max_memory.append('cpu:80GiB') - print('Max_memory allocation: \n', max_memory) + max_memory.append(f"{i}:12GiB") + max_memory.append("cpu:80GiB") + print("Max_memory allocation: \n", max_memory) max_memory = [v.split(":") for v in (max_memory or [])] max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory} @@ -69,76 +72,89 @@ def build_model_and_enc(model, model_path, gpu_num): ], **kwargs, ) - model = dispatch_model(model, device_map=device_map, - offload_dir=os.path.join(model_path, 'offload_dir')) + model = dispatch_model( + model, + device_map=device_map, + offload_dir=os.path.join(model_path, "offload_dir"), + ) return model, enc def get_loaders(dataset_path: str, enc, seqlen): - print('Loading dataset c4/realnewslike') + print("Loading dataset c4/realnewslike") testenc = load_dataset( - 'json', - data_files={'validation':dataset_path}, - split = 'validation' - ) - import pdb;pdb.set_trace() - testenc = enc(' '.join(testenc[:1100]['text']), return_tensors='pt') - testenc = testenc.input_ids[:, :(256 * seqlen)] + "json", data_files={"validation": dataset_path}, split="validation" + ) + import pdb + + pdb.set_trace() + testenc = enc(" ".join(testenc[:1100]["text"]), return_tensors="pt") + testenc = testenc.input_ids[:, : (256 * seqlen)] class TokenizerWrapper: def __init__(self, input_ids): self.input_ids = input_ids + testenc = TokenizerWrapper(testenc) - + return testenc def main(): # Load model - model_path = '/data2/share_data/internlm_model_data/internlm2_5-7b-chat' + model_path = "/data2/share_data/internlm_model_data/internlm2_5-7b-chat" model = get_llama2(model_path) model = model.eval() - gpus = os.getenv('CUDA_VISIBLE_DEVICES') - if gpus == '' or gpus is None: + gpus = os.getenv("CUDA_VISIBLE_DEVICES") + if gpus == "" or gpus is None: gpu_num = 0 else: - gpu_num = len(gpus.split(',')) + gpu_num = len(gpus.split(",")) model, enc = 
build_model_and_enc(model, model_path, gpu_num) model.seqlen = 2048 # Load dataset - dataset_path = './c4/c4-train.00000-of-00512.json' - testenc = get_loaders(dataset_path=dataset_path, - enc=enc, - seqlen=model.seqlen) - - import pdb;pdb.set_trace() + dataset_path = "./c4/c4-train.00000-of-00512.json" + testenc = get_loaders(dataset_path=dataset_path, enc=enc, seqlen=model.seqlen) + + import pdb + + pdb.set_trace() testenc = testenc.input_ids.to(model.device) - - config_file = './outputs/config.json' - amct.create_quant_cali_config(config_file=config_file, - model=model, - quant_layers={'kv_cache_quant_laye': - [f'model.layers.{i}.attention.wqkv' for i in range(32)]}, - config_defination=None) - - record_file = './outputs/record.txt' - quant_cali_model = amct.create_quant_cali_model(config_file=config_file, - record_file=record_file, - model=model).npu() + + config_file = "./outputs/config.json" + amct.create_quant_cali_config( + config_file=config_file, + model=model, + quant_layers={ + "kv_cache_quant_laye": [ + f"model.layers.{i}.attention.wqkv" for i in range(32) + ] + }, + config_defination=None, + ) + + record_file = "./outputs/record.txt" + quant_cali_model = amct.create_quant_cali_model( + config_file=config_file, record_file=record_file, model=model + ).npu() # Do inference to get quantize factors batch_num = 3 test_start_time = time.time() for i in tqdm.tqdm(range(batch_num), desc="getting quantize factors..."): - batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(model.device) + batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to( + model.device + ) with torch.no_grad(): quant_cali_model(batch) test_end_time = time.time() total_time = test_end_time - test_start_time - print('Get quantize factors taken: ', total_time // 60, 'min ', total_time%60, 's' ) + print( + "Get quantize factors taken: ", total_time // 60, "min ", total_time % 60, "s" + ) -if __name__ == '__main__': +if __name__ == "__main__": main() From 8bbec892372337a9c9e5113397c7a0a387c8b3f9 Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Tue, 12 Nov 2024 08:05:28 +0000 Subject: [PATCH 4/7] update code --- .../lmdeploy_ext/quants/ascend_kv.md | 2 +- .../lmdeploy_ext/quants/ascend_kv.py | 50 ++++++++++--------- dlinfer/ops/llm.py | 8 +-- 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md b/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md index b670852..fa5bf1b 100644 --- a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md +++ b/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md @@ -24,7 +24,7 @@ amct_pytorch==0.22.2(Ascend-cann-amct_8.0.RC2) 在当前目录执行如下命令,得到量化因子记录文件,用户根据实际情况修改示例程序中的model_path和dataset_path,并根据模型结构修改quant_layers。 ```shell -python3 ascend_kv.py +VISIBLE_DEVICES=0,1 python3 ascend_kv.py ``` 推理成功后,在当前目录会生成量化日志文件./amct_log/amct_pytorch.log和./outputs文件夹,该文件夹内包含以下内容: diff --git a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py b/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py index 50794ea..49ec46e 100644 --- a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py +++ b/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py @@ -1,11 +1,10 @@ import os import time +import json import tqdm import torch -import torch.nn as nn import dlinfer -import transformers from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig from accelerate import infer_auto_device_map, dispatch_model from accelerate.utils.modeling import get_balanced_memory @@ -29,6 +28,18 @@ def skip(*args, **kwargs): return model +def 
get_layer_num(model_path): + config_file = f"{model_path}/config.json" + with open(config_file, "r") as json_file: + model_config = json.load(json_file) + return model_config["num_hidden_layers"] + + +def get_gpu_memory(max_entry_count=0.8): + gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) + return int(gpu_memory * max_entry_count) + + def build_model_and_enc(model, model_path, gpu_num): config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) if "mpt" in config.__class__.__name__.lower(): @@ -40,15 +51,10 @@ def build_model_and_enc(model, model_path, gpu_num): model_path, use_fast=False, trust_remote_code=True ) - # Move the model to GPU (as much as possible) for LM evaluation - # max_memory = ['0:16GiB', '1:16GiB','2:16GiB', 'cpu:30GiB'], '0' means the first GPU that you specify. - # I don't recommend use 16GiB, we need to reserve some space for other tensors during calculation - # please see the recommand memeory allocation in the Word file - # Adjust the max_size accroding to the real situation - # a clever way: max_memory = [] + gpu_memory = get_gpu_memory() for i in range(gpu_num): - max_memory.append(f"{i}:12GiB") + max_memory.append(f"{i}:{gpu_memory}GiB") max_memory.append("cpu:80GiB") print("Max_memory allocation: \n", max_memory) @@ -82,13 +88,10 @@ def build_model_and_enc(model, model_path, gpu_num): def get_loaders(dataset_path: str, enc, seqlen): - print("Loading dataset c4/realnewslike") + print("Loading dataset...") testenc = load_dataset( "json", data_files={"validation": dataset_path}, split="validation" ) - import pdb - - pdb.set_trace() testenc = enc(" ".join(testenc[:1100]["text"]), return_tensors="pt") testenc = testenc.input_ids[:, : (256 * seqlen)] @@ -106,7 +109,7 @@ def main(): model_path = "/data2/share_data/internlm_model_data/internlm2_5-7b-chat" model = get_llama2(model_path) model = model.eval() - gpus = os.getenv("CUDA_VISIBLE_DEVICES") + gpus = os.getenv("VISIBLE_DEVICES") if gpus == "" or gpus is None: gpu_num = 0 else: @@ -118,27 +121,28 @@ def main(): dataset_path = "./c4/c4-train.00000-of-00512.json" testenc = get_loaders(dataset_path=dataset_path, enc=enc, seqlen=model.seqlen) - import pdb - - pdb.set_trace() testenc = testenc.input_ids.to(model.device) + layer_num = get_layer_num(model_path) + config_file = "./outputs/config.json" + internlm_layers = [f"model.layers.{i}.attention.wqkv" for i in range(layer_num)] + llama_layers = [ + f"model.layers.{i}.self_attn.{proj}" + for i in range(layer_num) + for proj in ["k_proj", "v_proj"] + ] amct.create_quant_cali_config( config_file=config_file, model=model, - quant_layers={ - "kv_cache_quant_laye": [ - f"model.layers.{i}.attention.wqkv" for i in range(32) - ] - }, + quant_layers={"kv_cache_quant_layers": internlm_layers}, config_defination=None, ) record_file = "./outputs/record.txt" quant_cali_model = amct.create_quant_cali_model( config_file=config_file, record_file=record_file, model=model - ).npu() + ) # Do inference to get quantize factors batch_num = 3 diff --git a/dlinfer/ops/llm.py b/dlinfer/ops/llm.py index 10ea569..3d146bc 100644 --- a/dlinfer/ops/llm.py +++ b/dlinfer/ops/llm.py @@ -155,8 +155,8 @@ def prefill_attention( "dlinfer::fill_kv_cache", ["key_cache", "value_cache"], default_value={ - "k_scales_zeros": None, - "v_scales_zeros": None, + "k_scales_zeros": [], + "v_scales_zeros": [], "quant_bits": 0, }, ) @@ -166,8 +166,8 @@ def fill_kv_cache( key_cache: Tensor, value_cache: Tensor, kv_indices: Tensor, - k_scales_zeros: Sequence[Optional[Tensor]], - 
v_scales_zeros: Sequence[Optional[Tensor]], + k_scales_zeros: Sequence[Tensor], + v_scales_zeros: Sequence[Tensor], quant_bits: int, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ From 82a8a01313526c5c7423f40fd0ef72eddaed63b1 Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Wed, 13 Nov 2024 06:13:36 +0000 Subject: [PATCH 5/7] update params --- .../graph/dicp/vendor/AtbGraph/conversion.py | 18 +++++++++++++++++- dlinfer/ops/llm.py | 12 +++++------- dlinfer/vendor/ascend/torch_npu_ops.py | 12 ++++++------ dlinfer/vendor/maca/maca_ops.py | 9 +++++++++ 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/conversion.py b/dlinfer/graph/dicp/vendor/AtbGraph/conversion.py index cc323ee..35661cc 100644 --- a/dlinfer/graph/dicp/vendor/AtbGraph/conversion.py +++ b/dlinfer/graph/dicp/vendor/AtbGraph/conversion.py @@ -126,7 +126,17 @@ def apply_rotary_pos_emb(self, q, k, cos, sin, q_out, k_out): return out @register_conversion("torch.ops.dlinfer.fill_kv_cache.default") - def fill_kv_cache(self, key, value, key_cache, value_cache, kv_indices): + def fill_kv_cache( + self, + key, + value, + key_cache, + value_cache, + kv_indices, + k_scales_zeros, + v_scales_zeros, + quant_bits, + ): key_cache_shape = key_cache.node.meta["val"].shape key_shape = key.node.meta["val"].shape key_cache_reshaped = self.get_proxy( @@ -171,6 +181,9 @@ def paged_attention_decode( softmax_scale, alibi_slopes, attn_output, + kv_scales, + kv_zeros, + quant_bits, ): q_head_num = num_q_heads kv_head_num = num_kv_heads @@ -370,6 +383,9 @@ def prefill_attention( block_size, mask, is_unpaged_prefill, + kv_scales, + kv_zeros, + quant_bits, ): # k_cache = self.get_proxy(atb_op.View, (k_cache, [-1, block_size, num_kv_heads, kv_head_size])) # v_cache = self.get_proxy(atb_op.View, (v_cache, [-1, block_size, num_kv_heads, kv_head_size])) diff --git a/dlinfer/ops/llm.py b/dlinfer/ops/llm.py index 3d146bc..8138709 100644 --- a/dlinfer/ops/llm.py +++ b/dlinfer/ops/llm.py @@ -155,8 +155,6 @@ def prefill_attention( "dlinfer::fill_kv_cache", ["key_cache", "value_cache"], default_value={ - "k_scales_zeros": [], - "v_scales_zeros": [], "quant_bits": 0, }, ) @@ -166,8 +164,8 @@ def fill_kv_cache( key_cache: Tensor, value_cache: Tensor, kv_indices: Tensor, - k_scales_zeros: Sequence[Tensor], - v_scales_zeros: Sequence[Tensor], + k_scales_zeros: Sequence[Optional[Tensor]], + v_scales_zeros: Sequence[Optional[Tensor]], quant_bits: int, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ @@ -222,9 +220,9 @@ def paged_decode_attention( softmax_scale: Optional[float], alibi_slopes: Optional[Sequence[float]], attn_output: Optional[Tensor], - kv_scales: Tensor, - kv_zeros: Tensor, - quant_bits: int, + kv_scales: Optional[Tensor], + kv_zeros: Optional[Tensor], + quant_bits: Optional[int], ) -> Tensor: """ Computes the multi-head attention over the query, key, and value tensors. 
diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py index 36135ff..ac5b628 100644 --- a/dlinfer/vendor/ascend/torch_npu_ops.py +++ b/dlinfer/vendor/ascend/torch_npu_ops.py @@ -181,9 +181,9 @@ def paged_decode_attention( softmax_scale: Optional[float], alibi_slopes: Optional[Sequence[float]], attn_output: Optional[Tensor], - kv_scales: Tensor, - kv_zeros: Tensor, - quant_bits: int, + kv_scales: Optional[Tensor], + kv_zeros: Optional[Tensor], + quant_bits: Optional[int], ) -> Tensor: if alibi_slopes is not None: raise RuntimeError( @@ -239,9 +239,9 @@ def paged_prefill_attention( softmax_scale: Optional[float], alibi_slopes: Optional[Sequence[float]], attn_output: Optional[Tensor], - kv_scales: Tensor, - kv_zeros: Tensor, - quant_bits: int, + kv_scales: Optional[Tensor], + kv_zeros: Optional[Tensor], + quant_bits: Optional[int], ) -> Tensor: if alibi_slopes is not None: raise RuntimeError( diff --git a/dlinfer/vendor/maca/maca_ops.py b/dlinfer/vendor/maca/maca_ops.py index a32c407..f1a2f55 100644 --- a/dlinfer/vendor/maca/maca_ops.py +++ b/dlinfer/vendor/maca/maca_ops.py @@ -182,6 +182,9 @@ def fill_kv_cache( key_cache: Tensor, value_cache: Tensor, kv_indices: Tensor, + k_scales_zeros: Sequence[Optional[Tensor]], + v_scales_zeros: Sequence[Optional[Tensor]], + quant_bits: int, ) -> Tuple[Tensor, Tensor]: kv_indices = kv_indices.squeeze(-1) maca_ext_ops.reshape_and_cache_new( @@ -204,6 +207,9 @@ def paged_decode_attention( softmax_scale: Optional[float], alibi_slopes: Optional[Sequence[float]], attn_output: Optional[Tensor], + kv_scales: Optional[Tensor], + kv_zeros: Optional[Tensor], + quant_bits: Optional[int], ) -> Tensor: if alibi_slopes is not None: raise RuntimeError("paged_decode_attention does not support alibi_slopes yet") @@ -269,6 +275,9 @@ def paged_prefill_attention( softmax_scale: Optional[float], alibi_slopes: Optional[Sequence[float]], attn_output: Optional[Tensor], + kv_scales: Optional[Tensor], + kv_zeros: Optional[Tensor], + quant_bits: Optional[int], ) -> Tensor: dim = query.size(-1) batch_size = block_table.size(0) From b8fc1b34c8c2ddb9486fb0df977a455e0a07be50 Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Wed, 13 Nov 2024 06:13:59 +0000 Subject: [PATCH 6/7] test ascend_kv_int8 --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f0ba727..7a32c75 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -12,7 +12,7 @@ on: env: CI_PATH: '/data2/wugeshui/GitHub/${{ github.repository }}/${GITHUB_RUN_NUMBER}' LMDEPLOY_PATH: '/data2/wugeshui/GitHub/lmdeploy' - LMDEPLOY_COMMIT_OR_BRANCH: 'main' + LMDEPLOY_COMMIT_OR_BRANCH: 'ascend_kv_int8' REPORT_DIR: /data2/wugeshui/GitHub/ci_log/test_reports concurrency: From c200c6e9400a99e77b59662b17574840ea47134a Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Wed, 13 Nov 2024 08:08:34 +0000 Subject: [PATCH 7/7] update docs --- dlinfer/ops/llm.py | 13 +++++++++- dlinfer/vendor/ascend/torch_npu_ops.py | 2 +- .../quant/ascend_kv_quant.md | 6 ++--- .../quant/ascend_scales_offsets.py | 25 +++---------------- 4 files changed, 20 insertions(+), 26 deletions(-) rename dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md => docs/quant/ascend_kv_quant.md (87%) rename dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py => docs/quant/ascend_scales_offsets.py (89%) diff --git a/dlinfer/ops/llm.py b/dlinfer/ops/llm.py index 8138709..7d7cb9c 100644 --- a/dlinfer/ops/llm.py +++ 
b/dlinfer/ops/llm.py @@ -155,6 +155,8 @@ def prefill_attention( "dlinfer::fill_kv_cache", ["key_cache", "value_cache"], default_value={ + "k_scales_zeros": tuple(), + "v_scales_zeros": tuple(), "quant_bits": 0, }, ) @@ -167,7 +169,7 @@ def fill_kv_cache( k_scales_zeros: Sequence[Optional[Tensor]], v_scales_zeros: Sequence[Optional[Tensor]], quant_bits: int, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> Tuple[Tensor, Tensor]: """ Fills the key-value cache with the provided key and value tensors. @@ -177,6 +179,9 @@ def fill_kv_cache( key_cache (Tensor): The existing key cache tensor. value_cache (Tensor): The existing value cache tensor. kv_indices (Tensor): The indices specifying where to store the key and value in the cache. + k_scales_zeros (Sequence[Optional[Tensor]]): The scales and zeros used to quantify key. + v_scales_zeros (Sequence[Optional[Tensor]]): The scales and zeros used to quantify value. + quant_bits (int): The bits which k/v is quantized into. Returns: Tuple[Tensor, Tensor]: @@ -242,6 +247,9 @@ def paged_decode_attention( softmax_scale (Optional[float]): The scale factor to apply to the attention logits before the softmax. alibi_slopes (Optional[Sequence[float]]): The slopes for the ALiBi attention bias, one for each head. attn_output (Optional[Tensor]): The computed attention output tensor. + kv_scales (Optional[Tensor]): The quantization factors for key and value. + kv_zeros (Optional[Tensor]): The quantization offset for key and value. + quant_bits (Optional[int]): The bits which k/v is quantized into. Returns: Tensor: The computed attention output tensor, alias of attn_output. @@ -316,6 +324,9 @@ def paged_prefill_attention( softmax_scale (Optional[float]): The scale factor to apply to the attention logits before the softmax. alibi_slopes (Optional[Sequence[float]]): The slopes for the ALiBi attention bias, one for each head. attn_output (Optional[Tensor]): The computed attention output tensor. + kv_scales (Optional[Tensor]): The quantization factors for key and value. + kv_zeros (Optional[Tensor]): The quantization offset for key and value. + quant_bits (Optional[int]): The bits which k/v is quantized into. Returns: Tensor: The computed attention output tensor, alias of attn_output. 
diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py index ac5b628..c1a7d56 100644 --- a/dlinfer/vendor/ascend/torch_npu_ops.py +++ b/dlinfer/vendor/ascend/torch_npu_ops.py @@ -125,7 +125,7 @@ def fill_kv_cache( k_scales_zeros: Sequence[Optional[Tensor]], v_scales_zeros: Sequence[Optional[Tensor]], quant_bits: int, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> Tuple[Tensor, Tensor]: _, head, dim = key.shape block_num, block_size = key_cache.shape[:2] block_total = block_num * block_size diff --git a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md b/docs/quant/ascend_kv_quant.md similarity index 87% rename from dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md rename to docs/quant/ascend_kv_quant.md index fa5bf1b..d637f9c 100644 --- a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.md +++ b/docs/quant/ascend_kv_quant.md @@ -21,10 +21,10 @@ amct_pytorch==0.22.2(Ascend-cann-amct_8.0.RC2) ## KV Cache量化示例 -在当前目录执行如下命令,得到量化因子记录文件,用户根据实际情况修改示例程序中的model_path和dataset_path,并根据模型结构修改quant_layers。 +在当前目录执行如下命令,得到量化因子记录文件,用户根据实际情况修改示例程序中的model_path(VL模型需要用其语言模型的权重)和dataset_path,并根据模型结构修改quant_layers。 -```shell -VISIBLE_DEVICES=0,1 python3 ascend_kv.py +```python +python3 ascend_scales_offsets.py ``` 推理成功后,在当前目录会生成量化日志文件./amct_log/amct_pytorch.log和./outputs文件夹,该文件夹内包含以下内容: diff --git a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py b/docs/quant/ascend_scales_offsets.py similarity index 89% rename from dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py rename to docs/quant/ascend_scales_offsets.py index 49ec46e..42efa28 100644 --- a/dlinfer/framework/lmdeploy_ext/quants/ascend_kv.py +++ b/docs/quant/ascend_scales_offsets.py @@ -1,7 +1,5 @@ import os -import time import json - import tqdm import torch import dlinfer @@ -9,13 +7,11 @@ from accelerate import infer_auto_device_map, dispatch_model from accelerate.utils.modeling import get_balanced_memory from datasets import load_dataset - import amct_pytorch as amct -def get_llama2(model_path, seqlen=2048): - def skip(*args, **kwargs): - pass +def get_model(model_path, seqlen=2048): + def skip(*args, **kwargs): ... 
torch.nn.init.kaiming_uniform_ = skip torch.nn.init.uniform_ = skip @@ -23,7 +19,6 @@ def skip(*args, **kwargs): model = AutoModelForCausalLM.from_pretrained( model_path, trust_remote_code=True, torch_dtype=torch.float16 ) - model.seqlen = seqlen return model @@ -83,7 +78,6 @@ def build_model_and_enc(model, model_path, gpu_num): device_map=device_map, offload_dir=os.path.join(model_path, "offload_dir"), ) - return model, enc @@ -100,20 +94,15 @@ def __init__(self, input_ids): self.input_ids = input_ids testenc = TokenizerWrapper(testenc) - return testenc def main(): # Load model model_path = "/data2/share_data/internlm_model_data/internlm2_5-7b-chat" - model = get_llama2(model_path) + model = get_model(model_path) model = model.eval() - gpus = os.getenv("VISIBLE_DEVICES") - if gpus == "" or gpus is None: - gpu_num = 0 - else: - gpu_num = len(gpus.split(",")) + gpu_num = torch.cuda.device_count() model, enc = build_model_and_enc(model, model_path, gpu_num) model.seqlen = 2048 @@ -146,18 +135,12 @@ def main(): # Do inference to get quantize factors batch_num = 3 - test_start_time = time.time() for i in tqdm.tqdm(range(batch_num), desc="getting quantize factors..."): batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to( model.device ) with torch.no_grad(): quant_cali_model(batch) - test_end_time = time.time() - total_time = test_end_time - test_start_time - print( - "Get quantize factors taken: ", total_time // 60, "min ", total_time % 60, "s" - ) if __name__ == "__main__":
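
Note on the new parameters introduced by this series: fill_kv_cache quantizes K/V into the int8 cache using k_scales_zeros / v_scales_zeros when quant_bits == 8 (the quant_int8 helper added to torch_npu_ops.py), and the paged attention ops pass kv_scales / kv_zeros to the NPU kernels as antiquant_scale / antiquant_offset so the cached values can be dequantized during attention. Below is a minimal stand-alone sketch of that round trip; it is not part of the patches, it assumes the usual asymmetric dequantization (q + offset) * scale for the kernel side, and the tensor shapes and scale/offset values are placeholders rather than real AMCT record-file output.

```python
import torch


# Mirrors the quant_int8 helper added inside fill_kv_cache:
# q = clamp(round(x / scale - offset), -128, 127)
def quant_int8(x: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
    return ((x / scale) - offset).round().clamp(-128, 127).to(torch.int8)


# Assumed inverse applied on the kernel side via antiquant_scale / antiquant_offset.
def dequant_int8(q: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
    return (q.to(torch.float32) + offset) * scale


key = torch.randn(4, 8, 128, dtype=torch.float16)   # placeholder [num_tokens, num_heads, head_dim]
k_scale = torch.tensor(0.05)                         # placeholder; real scales come from record.txt
k_offset = torch.tensor(0.0)                         # placeholder zero point

k_int8 = quant_int8(key, k_scale, k_offset)          # what the int8 key cache stores
k_restored = dequant_int8(k_int8, k_scale, k_offset)  # what the attention kernel works with
print("max abs error:", (key.float() - k_restored).abs().max().item())
```

The reconstruction error stays on the order of the scale, which is the accuracy trade-off behind enabling quant_policy=8.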