optimize chatglm3 in pytorch engine (#1215)
* optimize chatglm3

* update docs
grimoire authored Mar 1, 2024
1 parent 0430349 commit f81404a
Showing 5 changed files with 107 additions and 106 deletions.
6 changes: 5 additions & 1 deletion docs/en/inference/pipeline.md
@@ -145,8 +145,12 @@ print(response)

## FAQs

- *RuntimeError: context has already been set*. If you got this for tp>1 in pytorch backend. Please make sure the python script has following
- **RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase**.

If you encounter this for tp>1 with the pytorch backend, please make sure the python script contains the following entry-point guard:

```python
if __name__ == '__main__':
```

Generally, in a multi-threading or multi-processing context, initialization code may need to run only once. The `if __name__ == '__main__':` guard ensures that such initialization runs only in the main program and is not repeated in each newly spawned process or thread.
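
A minimal sketch of such an entry script (the `pipeline`/`PytorchEngineConfig` usage follows the example earlier in this document; the model name and `tp=2` are placeholder values, not requirements):

```python
from lmdeploy import pipeline, PytorchEngineConfig


def main():
    # tp>1 with the pytorch backend spawns worker processes, so all setup
    # lives behind the __main__ guard below.
    pipe = pipeline('internlm/internlm2-chat-7b',
                    backend_config=PytorchEngineConfig(tp=2))
    response = pipe(['Hi, please introduce yourself'])
    print(response)


if __name__ == '__main__':
    main()
```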
6 changes: 5 additions & 1 deletion docs/zh_cn/inference/pipeline.md
@@ -147,8 +147,12 @@ print(response)

## FAQs

- *RuntimeError: context has already been set*. If you encounter this error when using tp>1 with the pytorch backend, please make sure the python script contains the following as its entry point
- **RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase**.

If you encounter this error when using tp>1 with the pytorch backend, please make sure the python script contains the following entry-point guard:

```python
if __name__ == '__main__':
```

Generally, in a multi-threading or multi-processing context, initialization code may need to run only once. The `if __name__ == '__main__':` guard ensures that such initialization runs only in the main program and is not repeated in each newly created process or thread.
7 changes: 5 additions & 2 deletions lmdeploy/pytorch/config.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Dict

import torch

@@ -58,6 +58,7 @@ class ModelConfig:
multi_query_attention: bool = False
json_config: dict = field(default_factory=dict)
hf_config: Any = None
init_kwargs: Dict[str, Any] = field(default_factory=dict)

def get_head_size(self):
"""get head size."""
@@ -109,14 +110,16 @@ def __build_chatglm():
bos_token_id = hf_config.bos_token_id
if bos_token_id is None:
bos_token_id = hf_config.pad_token_id
init_kwargs = dict(empty_init=False)
return ModelConfig(
hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_layers,
num_attention_heads=hf_config.num_attention_heads,
num_key_value_heads=hf_config.multi_query_group_num,
bos_token_id=bos_token_id,
eos_token_id=hf_config.eos_token_id,
head_dim=head_dim)
head_dim=head_dim,
init_kwargs=init_kwargs)

def __build_gemma():
return ModelConfig(
9 changes: 6 additions & 3 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -478,7 +478,8 @@ def _build_model(self,
hf_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code)
trust_remote_code=trust_remote_code,
**self.model_config.init_kwargs)
hf_model.eval()
hf_model.config.use_cache = True

@@ -671,7 +672,8 @@ def _broadcast_config(cache_config):
model = AutoModelForCausalLM.from_config(
config,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code)
trust_remote_code=trust_remote_code,
**model_config.init_kwargs)
if rank == 0:
device_map = _create_device_map(model, world_size)
_add_adapters(model, adapters)
@@ -688,7 +690,8 @@ def _broadcast_config(cache_config):
model_path,
torch_dtype=torch_dtype,
device_map=device_map,
trust_remote_code=trust_remote_code)
trust_remote_code=trust_remote_code,
**model_config.init_kwargs)
_load_adapters(param_model, adapters, device_map=device_map)
__load_state_dict_assign(param_model, model)
param_model = param_model.to('meta')
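
Taken together, the config.py and model_agent.py hunks above thread per-model constructor kwargs (for chatglm, `empty_init=False`) from `ModelConfig.init_kwargs` into the HuggingFace loading calls. A condensed sketch of that flow, with class and function names simplified for illustration rather than matching the actual module layout:

```python
from dataclasses import dataclass, field
from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM


@dataclass
class ModelConfig:
    """Subset of lmdeploy.pytorch.config.ModelConfig relevant to this change."""
    init_kwargs: Dict[str, Any] = field(default_factory=dict)


def build_chatglm_config() -> ModelConfig:
    # chatglm's remote modeling code accepts an `empty_init` flag; this commit
    # passes empty_init=False when the pytorch engine builds the model.
    return ModelConfig(init_kwargs=dict(empty_init=False))


def build_model(model_path: str, model_config: ModelConfig,
                torch_dtype: torch.dtype = torch.float16):
    # The extra kwargs are forwarded verbatim to from_pretrained, mirroring
    # the **self.model_config.init_kwargs added in model_agent.py.
    return AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        **model_config.init_kwargs)
```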
185 changes: 86 additions & 99 deletions lmdeploy/pytorch/models/chatglm2.py
@@ -10,7 +10,8 @@
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from transformers.modeling_outputs import BaseModelOutputWithPast

from ..dist_utils import rowwise_parallelize_linear_fn, try_to_local
from ..dist_utils import (colwise_parallelize_linear,
rowwise_parallelize_linear_fn, try_to_local)
from ..kernels import paged_attention_fwd
from .functional import fill_kv_cache

@@ -50,28 +51,27 @@ def split_tensor_along_last_dim(
return tensor_list


# @torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor,
rope_cache: torch.Tensor) -> torch.Tensor:
# x: [sq, b, np, hn]
sq, _, np, _ = x.size(0), x.size(1), x.size(2), x.size(3)
rot_dim = rope_cache.shape[-2] * 2
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
# truncate to support variable sizes
sq, hn = x.size(0), x.size(-1)
xslice = x[..., :hn // 2]
rope_cache = rope_cache[:sq]
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
x_out2 = torch.stack(
xshaped = xslice.unflatten(-1, (-1, 2))
rope_cache = rope_cache.unsqueeze(2)

# inplace
torch.stack(
[
xshaped[..., 0] * rope_cache[..., 0] -
xshaped[..., 1] * rope_cache[..., 1],
xshaped[..., 1] * rope_cache[..., 0] +
xshaped[..., 0] * rope_cache[..., 1],
],
-1,
out=xshaped,
)
x_out2 = x_out2.flatten(3)
return torch.cat((x_out2, x_pass), dim=-1)
return x


class PatchedSelfAttention(nn.Module):
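
Because the rendered hunk above interleaves removed and added lines, the rewritten `apply_rotary_pos_emb` is hard to read in place. Consolidated from the added lines (a reconstruction of the hunk shown only for readability, not a separate implementation), it comes out roughly as:

```python
import torch


def apply_rotary_pos_emb(x: torch.Tensor,
                         rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]; only the first half of the head dim is rotated
    # truncate to support variable sizes
    sq, hn = x.size(0), x.size(-1)
    xslice = x[..., :hn // 2]
    rope_cache = rope_cache[:sq]
    xshaped = xslice.unflatten(-1, (-1, 2))
    rope_cache = rope_cache.unsqueeze(2)

    # the rotation is written back in place through the view of x,
    # so the untouched second half of x needs no concatenation
    torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] -
            xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] +
            xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
        out=xshaped,
    )
    return x
```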
@@ -81,28 +81,55 @@ class PatchedSelfAttention(nn.Module):
the same size.
"""

def _distribute_qkv_linear(self, mod: nn.Module, device_mesh: DeviceMesh):
"""distribute qkv linear."""
sections = [
self.num_attention_heads_per_partition *
self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head,
]
for name, param in mod.named_parameters():
splited_param = param.split(sections, dim=0)
updated_param = []
for p in splited_param:
dist_tensor = distribute_tensor(p, device_mesh, [Shard(0)])
dist_tensor = try_to_local(dist_tensor)
updated_param.append(dist_tensor)
param = torch.cat(updated_param)
dist_param = torch.nn.Parameter(param)
mod.register_parameter(name, dist_param)

def _distribute_qkv_lora_linear(self, module: nn.Module,
device_mesh: DeviceMesh):
"""distribute qkv lora linear."""
to_local = True
self._distribute_qkv_linear(
module.base_layer,
device_mesh=device_mesh,
)
for mod in module.lora_A.values():
colwise_parallelize_linear(mod,
device_mesh=device_mesh,
to_local=to_local)
for mod in module.lora_B.values():
self._distribute_qkv_linear(
mod,
device_mesh=device_mesh,
)
module._tp_mode = 'colwise'

def _distribute_partition_fn(self, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['query_key_value']:
sections = [
self.num_attention_heads_per_partition *
self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition *
self.hidden_size_per_attention_head,
]
for name, param in mod.named_parameters():
splited_param = param.split(sections, dim=0)
updated_param = []
for p in splited_param:
dist_tensor = distribute_tensor(p, device_mesh, [Shard(0)])
dist_tensor = try_to_local(dist_tensor)
updated_param.append(dist_tensor)
param = torch.cat(updated_param)
dist_param = torch.nn.Parameter(param)
mod.register_parameter(name, dist_param)
from peft.tuners.lora import Linear as LoraLinear
if isinstance(mod, LoraLinear):
self._distribute_qkv_lora_linear(mod, device_mesh)
else:
self._distribute_qkv_linear(mod, device_mesh)
elif mod_name in ['dense']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
@@ -117,11 +144,8 @@ def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
def _contiguous_batching_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
kv_cache: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
# hidden_states: [sq, b, h]
@@ -140,6 +164,12 @@ def _contiguous_batching_forward(

context = self.context.context
history_lengths = context.history_lengths
max_seq_length = context.max_seq_length
q_start_loc = context.q_start_loc
q_seq_length = context.q_seq_length
kv_seq_length = context.kv_seq_length
block_offsets = context.block_offsets

mixed_x_layer = self.query_key_value(hidden_states)

if self.multi_query_attention:
@@ -154,18 +184,12 @@
],
dim=-1,
)
query_layer = query_layer.view(query_layer.size()[:-1] + (
self.num_attention_heads_per_partition // world_size,
self.hidden_size_per_attention_head,
))
key_layer = key_layer.view(key_layer.size()[:-1] + (
self.num_multi_query_groups_per_partition // world_size,
self.hidden_size_per_attention_head,
))
value_layer = value_layer.view(value_layer.size()[:-1] + (
self.num_multi_query_groups_per_partition // world_size,
self.hidden_size_per_attention_head,
))
query_layer = query_layer.unflatten(
-1, (-1, self.hidden_size_per_attention_head))
key_layer = key_layer.unflatten(
-1, (-1, self.hidden_size_per_attention_head))
value_layer = value_layer.unflatten(
-1, (-1, self.hidden_size_per_attention_head))
else:
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition // world_size,
@@ -178,47 +202,31 @@
value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

# [b, sq, np, hn]
query_layer, key_layer, value_layer = [
k.transpose(0, 1) for k in [query_layer, key_layer, value_layer]
]

# adjust key and value for inference
if kv_cache is not None:
cache_k, cache_v = kv_cache
q_start_loc = context.q_start_loc
q_seq_length = context.q_seq_length

q_start_loc: torch.Tensor
history_lengths = q_seq_length.new_tensor(history_lengths)
kv_seq_length = q_seq_length + history_lengths
max_seq_len = q_seq_length.max().item()
fill_kv_cache(key_layer[0],
value_layer[0],
cache_k,
cache_v,
q_start_loc,
q_seq_length,
block_offsets=context.block_offsets,
history_lengths=history_lengths)

if use_cache:
kv_cache = (key_layer, value_layer)
else:
kv_cache = None
cache_k, cache_v = kv_cache
fill_kv_cache(key_layer[0],
value_layer[0],
cache_k,
cache_v,
q_start_loc,
q_seq_length,
block_offsets=block_offsets,
history_lengths=history_lengths,
context=context)

# ==================================
# core attention computation
# ==================================

context_layer = torch.empty_like(query_layer)

block_offsets = context.block_offsets

context_layer = query_layer
paged_attention_fwd(query_layer,
cache_k,
cache_v,
Expand All @@ -227,7 +235,7 @@ def _contiguous_batching_forward(
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_seq_len)
max_seqlen=max_seq_length)

context_layer = context_layer.transpose(1, 0).flatten(-2)

@@ -250,10 +258,8 @@ def forward(
):
return self._contiguous_batching_forward(
hidden_states,
attention_mask,
rotary_pos_emb,
kv_cache,
use_cache,
)


@@ -292,15 +298,9 @@ def _contiguous_batching_forward(
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = True,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None):
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache # noqa: E501
return_dict = return_dict if return_dict is not None else self.config.use_return_dict # noqa: E501
inputs_embeds: Optional[torch.Tensor] = None):
output_hidden_states = False
use_cache = True

batch_size, seq_length = input_ids.shape

Expand All @@ -312,12 +312,6 @@ def _contiguous_batching_forward(
past_key_values = self.get_prompt(batch_size=batch_size,
device=input_ids.device,
dtype=inputs_embeds.dtype)
if attention_mask is not None:
attention_mask = torch.cat([
attention_mask.new_ones(
(batch_size, self.pre_seq_len)), attention_mask
],
dim=-1)

# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
@@ -339,11 +333,6 @@
use_cache=use_cache,
output_hidden_states=output_hidden_states)

if not return_dict:
return tuple(v for v in [
hidden_states, presents, all_hidden_states, all_self_attentions
] if v is not None)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
@@ -371,6 +360,4 @@ def forward(
full_attention_mask=full_attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
)
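
For context on the `_distribute_qkv_linear` logic above: ChatGLM packs the query, key and value projections into a single `query_key_value` weight, and the k/v sections are much smaller than the q section, so a plain chunk of dim 0 across ranks would not give every rank its share of q, k and v. The weight is therefore split per section and each section is sharded along dim 0. A standalone sketch of that idea with hypothetical sizes, using `chunk` in place of DTensor's `distribute_tensor(..., [Shard(0)])`:

```python
import torch

# Hypothetical ChatGLM-like sizes: 32 query heads, 2 KV groups, head_dim 128, tp=2.
num_heads, num_kv_groups, head_dim, world_size = 32, 2, 128, 2
hidden_size = num_heads * head_dim

# query_key_value packs [q | k | v] along dim 0.
sections = [
    num_heads * head_dim,
    num_kv_groups * head_dim,
    num_kv_groups * head_dim,
]
qkv_weight = torch.randn(sum(sections), hidden_size)


def shard_packed_qkv(weight: torch.Tensor, rank: int) -> torch.Tensor:
    """Split the packed weight into its q/k/v sections, keep this rank's slice
    of each section along dim 0, and re-concatenate."""
    shards = [part.chunk(world_size, dim=0)[rank]
              for part in weight.split(sections, dim=0)]
    return torch.cat(shards, dim=0)


local_weight = shard_packed_qkv(qkv_weight, rank=0)
# Each rank keeps 16 query heads plus one key group and one value group.
assert local_weight.shape == ((16 + 1 + 1) * head_dim, hidden_size)
```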
