Merge branch 'main' into fix-flash-kernel-sm70
yaoqian committed Jan 26, 2024
2 parents c6407db + 6a15f68 commit 374818a
Showing 28 changed files with 784 additions and 133 deletions.
66 changes: 66 additions & 0 deletions docs/en/advance/pytorch_new_model.md
@@ -243,6 +243,72 @@ MODEL_MAP.update({

With these adjustments, the model can now use multiple GPUs to deploy large language models (LLMs), distributing the computation efficiently across devices.
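For intuition about what such a distributed rewrite does, here is a minimal, self-contained sketch of column-wise weight splitting with `torch.distributed`. It is only an illustration of the idea, assuming an already initialized process group — it is not lmdeploy's actual implementation, and the function name is made up for this example.

```python
# Illustrative sketch of tensor parallelism, not lmdeploy's API.
# Assumes torch.distributed has already been initialized (e.g. via torchrun).
import torch
import torch.distributed as dist


def colwise_parallel_linear(weight: torch.Tensor, x: torch.Tensor,
                            rank: int, world_size: int) -> torch.Tensor:
    """Each rank keeps a slice of the output features and computes a partial result."""
    out_features = weight.shape[0]
    chunk = out_features // world_size
    local_w = weight[rank * chunk:(rank + 1) * chunk]   # [chunk, in_features]
    local_out = x @ local_w.t()                         # [..., chunk]
    # Gather the per-rank slices back into the full output.
    gathered = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(gathered, local_out)
    return torch.cat(gathered, dim=-1)
```

Splitting along the output dimension keeps each GPU's share of weights and activations small; the only communication is the gather (or an all-reduce for row-wise splits) at the end of the layer.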

## Debug Module

When the output of the model does not meet expectations, we may want to debug a specific module to determine whether the added rewrite is correct. `lmdeploy.pytorch` provides some tools to assist with accuracy alignment. Let's take the `LlamaAttention` module as an example.

First, create an instance of the module that we want to debug:

```python
import torch
from transformers import AutoModelForCausalLM

# get module
model_path = 'meta-llama/Llama-2-7b-chat-hf'
dtype = torch.float16
model = AutoModelForCausalLM.from_pretrained(model_path).to(dtype).cuda()
self_attn = model.model.layers[0].self_attn
```

Extract the inputs/outputs with `ModuleIOExtractor`.

```python
from lmdeploy.pytorch.tools.make_inputs import ModuleIOExtractor

# extract module input/output
input_ids = torch.tensor([[1, 2, 3, 4, 5]]).cuda()
extractor = ModuleIOExtractor(model, self_attn)
attn_args, attn_kwargs, attn_output = extractor.extract(input_ids)
```

The inputs of the rewritten module differ from those of the original module:

1. The module requires some special inputs, which are passed in through `StepContext`. We can create one with `make_step_context`.
2. `input_ids` and `hidden_states` should be converted to the continuous layout; `continuous_tensor` does this for us.
3. `past_key_value` should be paged to meet the demands of paged attention.

For these reasons, the extracted inputs need to be updated:

```python
from lmdeploy.pytorch.tools.make_inputs import make_step_context
from lmdeploy.pytorch.tools.layout_convert import continuous_tensor

# create patched input/output
context = make_step_context(input_ids,
                            kv_cache_dtype=dtype)
seq_length = context.seq_length
attn_kwargs['hidden_states'] = continuous_tensor(
    attn_kwargs['hidden_states'],
    seq_length)
attn_kwargs['past_key_value'] = context.kv_caches[0]
```
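For intuition, the sketch below shows roughly what the continuous layout amounts to: padded per-sequence tensors are packed into one unpadded token dimension. This is an illustration of the idea only, not the actual code behind `continuous_tensor`.

```python
# Illustration only: pack padded [batch, max_seq, ...] data into a single
# unpadded token axis, which is roughly what the continuous layout means.
import torch


def pack_continuous(padded: torch.Tensor, seq_length: torch.Tensor) -> torch.Tensor:
    """padded: [batch, max_seq, ...]; seq_length: [batch] valid lengths."""
    pieces = [padded[i, :int(n)] for i, n in enumerate(seq_length)]
    return torch.cat(pieces, dim=0)[None]   # [1, total_valid_tokens, ...]


# e.g. two sequences of lengths 3 and 2 packed into 5 tokens
hidden = torch.randn(2, 4, 8)
print(pack_continuous(hidden, torch.tensor([3, 2])).shape)  # torch.Size([1, 5, 8])
```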

Then you can run the rewritten module and check that its results match the original (note that the reference output must also be converted with `continuous_tensor` before the comparison).

```python
from lmdeploy.pytorch.models import patch

# patch and test
patched_self_attn = patch(self_attn, extra_args=['context'])
patched_output = patched_self_attn.patched_forward(*attn_args,
                                                   **attn_kwargs,
                                                   context=context)
torch.testing.assert_close(patched_output[0],
                           continuous_tensor(attn_output[0], seq_length))
```

Adjust the rewritten module until its output is aligned with the original.
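If the comparison fails only by a small margin, the mismatch may come from fp16 rounding rather than a wrong rewrite. Relaxing the tolerances of `torch.testing.assert_close` helps distinguish the two cases; the values below are arbitrary examples, not recommended thresholds.

```python
# Loosen tolerances to rule out fp16 rounding noise (example values only).
torch.testing.assert_close(patched_output[0],
                           continuous_tensor(attn_output[0], seq_length),
                           rtol=1e-3, atol=1e-3)
```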

## Appendix

### context info
9 changes: 6 additions & 3 deletions docs/en/build.md
@@ -17,7 +17,8 @@ The docker image is `openmmlab/lmdeploy-builder:cuda11.8`. Make sure that docker
In the root directory of the lmdeploy source code, please run the following command:

```shell
cd lmdeploy # the home folder of lmdeploy source code
# the home folder of lmdeploy source code
cd lmdeploy
bash builder/manywheel/build_all_wheel.sh
```

@@ -67,8 +68,10 @@ Then, follow the steps below to set up the compilation environment:
```
- build and install lmdeploy libraries:
```shell
apt install ninja-build # install ninja
cd lmdeploy # the home folder of lmdeploy
# install ninja
apt install ninja-build
# the home folder of lmdeploy
cd lmdeploy
mkdir build && cd build
sh ../generate.sh
ninja -j$(nproc) && ninja install
50 changes: 16 additions & 34 deletions docs/en/quantization/kv_int8.md
@@ -15,50 +15,26 @@ dequant: f = q * scale + zp

### **Step One**

Convert the Hugging Face model format to the TurboMind inference format to create a workspace directory.

```bash
lmdeploy convert internlm-chat-7b /path/to/internlm-chat-7b
```

If you already have a workspace directory, skip this step.

### **Step Two**

Get the quantization parameters by these two steps:
Get the quantization parameters and save them to the original HF model directory:

```bash
# get minmax
export HF_MODEL=internlm/internlm-chat-7b

lmdeploy lite calibrate \
$HF_MODEL \
--calib-dataset 'ptb' \ # Support c4, ptb, wikitext2, pileval
--calib-samples 128 \ # Number of samples in the calibration set, if the memory is not enough, it can be adjusted appropriately
--calib-seqlen 2048 \ # Length of a single text, if the memory is not enough, you can adjust it appropriately
--work-dir $WORK_DIR \ # Directory for saving quantized statistical parameters and quantized weights in Pytorch format

# get quant parameters
lmdeploy lite kv_qparams \
$WORK_DIR \ # Directory of the last output
workspace/triton_models/weights/ \ # Directory to save the quantization parameters
--num-tp 1 \ # Number of GPUs used for Tensor parallelization, keep it consistent with deploy.py
--calib-dataset 'ptb' \
--calib-samples 128 \
--calib-seqlen 2048 \
--work-dir $HF_MODEL
```
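For a rough picture of what the calibration statistics are used for, the per-tensor min/max values collected above can be turned into a scale and zero-point that satisfy the `f = q * scale + zp` relation quoted at the top of this document. The sketch below assumes unsigned 8-bit quantization and is only an illustration, not the code actually run by `lmdeploy lite`.

```python
# Sketch: derive (scale, zp) from calibration min/max for the affine
# dequantization f = q * scale + zp with q in [0, 255].
# Illustration only, not the lmdeploy implementation.
import torch


def qparams_from_minmax(x_min: torch.Tensor, x_max: torch.Tensor):
    scale = (x_max - x_min).clamp(min=1e-8) / 255.0
    zp = x_min
    return scale, zp


def quant(f, scale, zp):
    return ((f - zp) / scale).round().clamp(0, 255).to(torch.uint8)


def dequant(q, scale, zp):
    return q.float() * scale + zp   # matches: f = q * scale + zp
```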

`kv_qparams` will generate fp32 scaling factors in the `weights` directory. The file format is a binary produced by `numpy.tofile`.

You can also first set `turbomind_dir` to a private directory, then copy the scaling factors into `workspace/triton_models/weights/`.

### **Step Three**

Modify `workspace/triton_models/weights/config.ini`:

- Set `quant_policy` to 4, which enables kv_cache int8.

### **Step Four**
### **Step Two**

Test the chat performance.
Test the chat performance. Note that `--quant-policy 4` enables the KV Cache int8 mode.

```bash
lmdeploy chat turbomind ./workspace
lmdeploy chat turbomind $HF_MODEL --model-format hf --quant-policy 4
```

## GPU Memory Test
Expand Down Expand Up @@ -102,3 +78,9 @@ Below is the result of PTQ quantization of `kCacheKVInt8` method with only 128 r
| Safety | crows_pairs | accuracy | 32.56 | 31.43 | +1.13 |

Note that both `kCacheKVInt8` and `WeightInt4` methods can be enabled at the same time.
Please refer to [w4a16](./w4a16.md) to apply `WeightInt4` quantization, and then start the chat like:

```shell
lmdeploy chat turbomind ./internlm-chat-7b-4bit --model-format awq --quant-policy 4
```
19 changes: 11 additions & 8 deletions docs/en/quantization/w4a16.md
@@ -33,14 +33,17 @@ This article comprises the following sections:
A single command execution is all it takes to quantize the model. The resulting quantized weights are then stored in the $WORK_DIR directory.

```shell
export HF_MODEL=internlm/internlm-chat-7b
export WORK_DIR=internlm/internlm-chat-7b-4bit

lmdeploy lite auto_awq \
$HF_MODEL \ # Model name or path, either model repo name on huggingface hub like 'internlm/internlm-chat-7b', or a model path in local host
--calib-dataset 'ptb' \ # Calibration dataset, supports c4, ptb, wikitext2, pileval
--calib-samples 128 \ # Number of samples in the calibration set, if memory is insufficient, you can appropriately reduce this
--calib-seqlen 2048 \ # Length of a single piece of text, if memory is insufficient, you can appropriately reduce this
--w-bits 4 \ # Bit number for weight quantization
--w-group-size 128 \ # Group size for weight quantization statistics
--work-dir $WORK_DIR # Folder storing Pytorch format quantization statistics parameters and post-quantization weight
$HF_MODEL \
--calib-dataset 'ptb' \
--calib-samples 128 \
--calib-seqlen 2048 \
--w-bits 4 \
--w-group-size 128 \
--work-dir $WORK_DIR
```
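For intuition about the `--w-bits 4` and `--w-group-size 128` options, the sketch below performs plain group-wise 4-bit quantization of a weight matrix. It is a simplified illustration, not the AWQ algorithm that `auto_awq` actually runs (AWQ additionally searches for activation-aware scales that protect salient weights).

```python
# Simplified group-wise 4-bit weight quantization (illustration only).
import torch


def quantize_groupwise(w: torch.Tensor, n_bits: int = 4, group_size: int = 128):
    out_features, in_features = w.shape
    assert in_features % group_size == 0
    g = w.reshape(out_features, in_features // group_size, group_size)
    w_min, w_max = g.amin(dim=-1, keepdim=True), g.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-5) / (2 ** n_bits - 1)
    zp = (-w_min / scale).round()
    q = ((g / scale).round() + zp).clamp(0, 2 ** n_bits - 1)
    w_hat = (q - zp) * scale                      # dequantized approximation
    return q.reshape_as(w), scale, zp, w_hat.reshape_as(w)


w = torch.randn(4096, 4096)
q, scale, zp, w_hat = quantize_groupwise(w)
print((w - w_hat).abs().mean())   # mean quantization error
```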

Typically, the `lmdeploy lite auto_awq` command above doesn't require filling in optional parameters, as the defaults usually suffice. For instance, when quantizing the [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) model, the command can be condensed as:
@@ -64,7 +67,7 @@ lmdeploy chat turbomind ./internlm-chat-7b-4bit --model-format awq
Alternatively, you can start the gradio server and interact with the model through the web at `http://{ip_addr}:{port}`.

```shell
lmdeploy serve gradio ./internlm-chat-7b-4bit --server_name {ip_addr} --server_port {port}
lmdeploy serve gradio ./internlm-chat-7b-4bit --server_name {ip_addr} --server_port {port} --model-format awq
```

## Evaluation
34 changes: 34 additions & 0 deletions docs/en/serving/gradio.md
@@ -0,0 +1,34 @@
# Steps to create a Hugging Face online demo

## create space

First, register for a Hugging Face account. After successful registration, click on your profile picture in the upper right corner and select “New Space” to create one. Follow the Hugging Face guide to choose the necessary configurations, and you will have a blank demo space ready.

## A demo for LMDeploy

Replace the content of `app.py` in your space with the following code:

```python
from lmdeploy.serve.gradio.turbomind_coupled import run_local
from lmdeploy.messages import TurbomindEngineConfig

backend_config = TurbomindEngineConfig(max_batch_size=1, cache_max_entry_count=0.05)
model_path = 'internlm/internlm2-chat-7b'
run_local(model_path, backend_config=backend_config, server_name="huggingface-space")
```

Create a `requirements.txt` file with the following content:

```
lmdeploy
```

## FAQs

- ZeroGPU compatibility issue: ZeroGPU suits PyTorch-style inference better than TurboMind, so you can either switch to the PyTorch backend or use standard GPUs.
- Gradio version issue: versions above 4.0.0 are currently not supported. You can pin the version in `app.py`, for example:
```python
import os
os.system("pip uninstall -y gradio")
os.system("pip install gradio==3.43.0")
```
2 changes: 1 addition & 1 deletion docs/en/serving/restful_api.md
@@ -159,7 +159,7 @@ lmdeploy serve gradio api_server_url --server-name ${gradio_ui_ip} --server-port

4. The `/v1/chat/interactive` API disables multi-round conversation by default. The input argument `prompt` can be either a single string or an entire chat history.

5. If you need to adjust other default parameters of the session, such as the content of fields like `system`, you can directly pass in the initialization parameters of the [dialogue template](https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py). For example, for the internlm-chat-7b model, you can set the `--meta_instruction` parameter when starting the `api_server`.
5. If you need to adjust other default parameters of the session, such as the content of fields like `system`, you can directly pass in the initialization parameters of the [dialogue template](https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py). For example, for the internlm-chat-7b model, you can set the `--meta-instruction` parameter when starting the `api_server`.

6. Regarding stop words, we only support characters that encode into a single index. Furthermore, there may be multiple indexes that decode into results containing the stop word. In such cases, if the number of these indexes is too large, we will only use the index encoded by the tokenizer. If you want to use a stop symbol that encodes into multiple indexes, consider performing string matching on the streaming client side, as sketched below; once a match is found, break out of the streaming loop.
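A minimal sketch of such client-side matching is shown below. It only assumes that the client receives the generated text incrementally as chunks; it is not tied to any particular lmdeploy client API, and the stop word is just an example.

```python
# Illustration: stop-word matching on the client side of a streaming loop.
# `stream_chunks` stands for whatever iterator yields incremental text.
from typing import Iterable


def stream_until_stop(stream_chunks: Iterable[str], stop_word: str) -> str:
    generated = ''
    for chunk in stream_chunks:
        generated += chunk
        if stop_word in generated:
            # Trim from the stop word onward and break out of the stream.
            return generated.split(stop_word, 1)[0]
    return generated


print(stream_until_stop(iter(['Hello', ' wor', 'ld<eoa>', ' extra']), '<eoa>'))
# -> 'Hello world'
```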

66 changes: 66 additions & 0 deletions docs/zh_cn/advance/pytorch_new_model.md
@@ -241,6 +241,72 @@ MODEL_MAP.update({

This takes advantage of multiple GPUs and makes it possible to deploy larger models.

## Module Debugging

When the output of the model does not meet expectations, we may want to debug a specific module to determine whether the added rewrite is correct. `lmdeploy.pytorch` provides some tools to help with accuracy alignment. Let's again take the `LlamaAttention` module mentioned above as an example.

First, we obtain an instance of the submodule we want to debug through the transformers API:

```python
import torch
from transformers import AutoModelForCausalLM

# get module
model_path = 'meta-llama/Llama-2-7b-chat-hf'
dtype = torch.float16
model = AutoModelForCausalLM.from_pretrained(model_path).to(dtype).cuda()
self_attn = model.model.layers[0].self_attn
```

Then, the `ModuleIOExtractor` tool can be used to capture a set of inputs and outputs of this module:

```python
from lmdeploy.pytorch.tools.make_inputs import ModuleIOExtractor

# extract module input/output
input_ids = torch.tensor([[1, 2, 3, 4, 5]]).cuda()
extractor = ModuleIOExtractor(model, self_attn)
attn_args, attn_kwargs, attn_output = extractor.extract(input_ids)
```

The inputs of the rewritten module differ slightly from those of the original module, mainly in three ways:

1. The model requires some special inputs and outputs, which are passed in as a `StepContext` and can be created with `make_step_context`.
2. Data such as `input_ids` and `hidden_states` are converted to the continuous layout, which can be done with `continuous_tensor`.
3. Because of paged caching, `past_key_value` needs to be paged.

For these reasons, we need to process the extracted inputs:

```python
from lmdeploy.pytorch.tools.make_inputs import make_step_context
from lmdeploy.pytorch.tools.layout_convert import continuous_tensor

# create patched input/output
context = make_step_context(input_ids,
                            kv_cache_dtype=dtype)
seq_length = context.seq_length
attn_kwargs['hidden_states'] = continuous_tensor(
    attn_kwargs['hidden_states'],
    seq_length)
attn_kwargs['past_key_value'] = context.kv_caches[0]
```

Then we can run the rewritten module and compare the correctness of the results (note that the reference output must also be converted to the continuous layout before the comparison).

```python
from lmdeploy.pytorch.models import patch

# patch and test
patched_self_attn = patch(self_attn, extra_args=['context'])
patched_output = patched_self_attn.patched_forward(*attn_args,
                                                   **attn_kwargs,
                                                   context=context)
torch.testing.assert_close(patched_output[0],
                           continuous_tensor(attn_output[0], seq_length))
```

The rewritten module can be debugged this way until the accuracy meets expectations.

## Appendix

### context structure
9 changes: 6 additions & 3 deletions docs/zh_cn/build.md
@@ -17,7 +17,8 @@ LMDeploy provides the build image `openmmlab/lmdeploy-builder:cuda11.8`. Before using
In the root directory of the lmdeploy source code, run the following command:

```shell
cd lmdeploy # the root directory of the lmdeploy source code
# the root directory of the lmdeploy source code
cd lmdeploy
bash builder/manywheel/build_all_wheel.sh
```

@@ -67,8 +68,10 @@ The wheel files are stored in the directory `builder/manywheel/cuda11.8_dist`.
```
- build and install lmdeploy:
```shell
apt install ninja-build # install ninja for faster builds
cd lmdeploy # the root directory of the lmdeploy source code
# install ninja for faster builds
apt install ninja-build
# the root directory of the lmdeploy source code
cd lmdeploy
mkdir build && cd build
sh ../generate.sh
ninja && ninja install