lmdeploy.pytorch 被设计用来简化新模型的支持以及原型的开发,新模型的支持依赖于 patch 机制,对原模型做修改以及功能添加,以期可以最大程度上复用模型的原始实现,减少工作量。
我们以 transformers 中的 llama 实现来介绍模型支持的流程
在开始之前,我们首先要了解一下模型的输入。lmdeploy.pytorch 的输入与标准 transformers 模型的输入略有不同,差异主要体现在如下方面:
- 由于支持了 continuous batching,一个 batch 的输入
input_ids
会被拼接成一维的长序列,然后unsqueeze(0)
来保证输入维度与 transformers 中相同。这样的输入不会影响 MLP 以及 RMSNorm 等模块的计算。 - 由于添加了对 paged attention 的支持,
past_key_value
不再是原来的大小,而是一组形状为[num_blocks, block_size, num_heads, head_dim]
的 cache 块,num_blocks 为总 block 数量,由可用显存大小决定,block_size 为预设的块大小。这样的输入改变会影响到 LlamaModel 和 LlamaAttention 的计算,因此要对这两个模块的实现进行修改。 - 由于上述输入的改变,模型中需要一些额外的输入来支持推理,比如 batch 中的序列起始位置和长度,kv cache 的 block table 等。这些输入并不在模块的 forward 参数列表中,我们需要维护一个上下文以获得这些输入。
上面的输入改动会影响 LlamaModel 和 LlamaAttention,首先我们来实现新的 LlamaModel,这是对原始实现的简化,我们删除了很多检查代码,以避免由于输入改变造成的断言失败,仅保留了最小程度的代码:
# lmdeploy/pytorch/models/llama.py
class LlamaModel(nn.Module):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""Rewrite implementation of LlamaModel.forward."""
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
# decoder layers
for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx]
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)
然后是对 LlamaAttention 模块的改写。按顺序实现如下操作:
- kqv proj
- rotary embedding
- 填充 kv cache
- MHA 计算
- o proj
continuous batching 和 kv cache 的改动对该模块的影响比较大
# lmdeploy/pytorch/models/llama.py
from lmdeploy.pytorch.kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd
class LlamaAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite of LlamaAttention.forward."""
context = self.context.context
history_lengths = context.history_lengths
position_ids_1d = context.position_ids_1d
block_offsets = context.block_offsets
# qkv proj
query_states = q_proj(hidden_states)
key_states = k_proj(hidden_states)
value_states = v_proj(hidden_states)
query_states = query_states.view(-1, num_heads, head_dim)
key_states = key_states.view(-1, num_kv_heads, head_dim)
value_states = value_states.view(-1, num_kv_heads, head_dim)
# rotary embedding
max_seq_len = position_ids.size(-1)
kv_seq_len = max_seq_len + max(history_lengths)
if kv_seq_len >= self.rotary_emb.max_seq_len_cached:
cos, sin = self.rotary_emb(value_states,
seq_len=kv_seq_len + 128)
query_states, key_states = apply_rotary_pos_emb(
query_states,
key_states,
self.rotary_emb.cos_cached,
self.rotary_emb.sin_cached,
position_ids,
position_ids_1d,
q_embed=query_states,
k_embed=key_states)
# fill kv cache
kv_seq_length = context.kv_seq_length
q_seq_length = context.q_seq_length
q_start_loc = context.q_start_loc
fill_kv_cache(key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
block_offsets=block_offsets,
history_lengths=history_lengths,
context=context)
# attention
attn_output = query_states
block_size = past_key_value[0].size(1)
paged_attention_fwd(
query_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_seq_len,
)
hidden_size = num_heads * head_dim
attn_output = attn_output.reshape(*hidden_states.shape[:-1], hidden_size)
# o proj
attn_output = o_proj(attn_output)
return attn_output, None, past_key_value
上面的代码有几处值得注意的地方,首先是 context 对象。我们需要 history_lengths、block_offsets 等参数辅助运算,这些参数无法通过模型的 forward 函数传递进来。因此我们维护了一个 context 对象,把几乎所有可能用到的输入参数都保存在其中,方便在各个模块间共享。context 对象可以通过 self.context.context
来访问,结构可以参考 context-结构。
另一个值得注意的地方就是自定义 kernel,由于输入形式的改变,原来的 LlamaAttention 实现变得不再适用,为了保证推理的速度和正确性,我们在 lmdeploy.pytorch.kernels 中实现了许多自定义的 triton kernel,上面的模块中就用到了 apply_rotary_pos_emb
,fill_kv_cache
和 paged_attention_fwd
,分别负责实现 rotary embedding,填充 kv cache 还有 attention 的计算。
有了上述的两个模块后,还需要将他们注册到 lmdeploy/pytorch/models/module_map.py
中,进行原模块与 patch 模块的映射
# lmdeploy/pytorch/models/module_map.py
MODEL_MAP.update({
'transformers.models.llama.modeling_llama.LlamaAttention':
'lmdeploy.pytorch.models.llama.LlamaAttention',
'transformers.models.llama.modeling_llama.LlamaModel':
'lmdeploy.pytorch.models.llama.LlamaModel'
})
完成注册后,Engine 在启动时就会将这两个模块 patch 成新的实现,完成后续的部署任务。
为了支持 Tensor 并发,需要对模型的权重做切分。让我们试着为上面接入的 Llama 模型添加 TP 的支持。
Llama 中涉及到 Tensor 并发的模块是 LlamaAttention 中的 qkvo proj 和 LlamaMLP 中的 gate,up 和 down proj。其中 o_proj 和 down_proj 需要按行切分,剩下的按列切分。我们可以在对应的模块中实现 _distribution_partition_fn
函数:
# lmdeploy/pytorch/models/llama.py
from ..dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
class LlamaAttention(nn.Module):
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['q_proj', 'k_proj', 'v_proj']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['o_proj']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
class LlamaMLP(nn.Module):
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['gate_proj', 'up_proj']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['down_proj']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
_distribute_partition_fn
会在加载模型权重时被调用,对应的权重会被按照特定的形式分配到对应的设备中。
按照目前的方案切分后的权重,需要对 o_proj 和 down_proj 的结果进行 all_reduce 操作才能得到正确的结果。可以选择将 all_reduce 放在模型的 forward 函数中,也可以选择另一种方案,添加 _distribute_output_fn
函数:
# lmdeploy/pytorch/models/llama.py
import torch.distributed as dist
class LlamaAttention(nn.Module):
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
class LlamaMLP(nn.Module):
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs)
return outputs
最后别忘了将 LlamaMLP 也注册进 module_map 中
# lmdeploy/pytorch/models/module_map.py
MODEL_MAP.update({
'transformers.models.llama.modeling_llama.LlamaMLP':
'lmdeploy.pytorch.models.llama.LlamaMLP'
})
这样就可以利用多卡的优势,让更大的模型部署成为可能
当模型的输出不符合预期时,我们会希望调试某个特定模块以确定添加的重写是否正确。lmdeploy.pytorch
提供了一些工具以帮助进行精度对齐。还是以上面提到的 LlamaAttention
模块为例。
首先,我们通过 transformers 的 API 得到想要调试的子模块的一个实例:
import torch
from transformers import AutoModelForCausalLM
# get module
model_path = 'meta-llama/Llama-2-7b-chat-hf'
dtype = torch.float16
model = AutoModelForCausalLM.from_pretrained(model_path).to(torch.float16).cuda()
self_attn = model.model.layers[0].self_attn
然后,使用 ModuleIOExtractor
工具可以生成该模块的一组输入输出
from lmdeploy.pytorch.tools.make_inputs import ModuleIOExtractor
# extract module input/output
input_ids = torch.tensor([[1, 2, 3, 4, 5]]).cuda()
extractor = ModuleIOExtractor(model, self_attn)
attn_args, attn_kwargs, attn_output = extractor.extract(input_ids)
重写模块的输入与原模块略有不同,主要体现在三方面:
- 模型需要一些特殊输入输出,他们以
StepContext
的形式传入,可以使用make_step_context
生成。 input_ids
,hidden_states
等数据都被 continuous 化,可以使用continuous_tensor
进行处理。- 由于 paged caching 的需要,
past_key_value
需要被 page 化处理。
基于以上原因,我们要对提取的输入进行加工:
from lmdeploy.pytorch.tools.make_inputs import make_step_context
from lmdeploy.pytorch.tools.layout_convert import continuous_tensor
# create patched input/output
context = make_step_context(input_ids,
kv_cache_dtype=dtype,
num_key_value_heads=32)
seq_length = context.q_seq_length
attn_kwargs['hidden_states'] = continuous_tensor(
attn_kwargs['hidden_states'],
seq_length)
attn_kwargs['past_key_value'] = context.kv_caches[0]
然后就可以启动重写,并比较结果正确性了。(注意输出也要 continuous 化后进行比较)
from lmdeploy.pytorch.models import patch
# patch and test
patched_self_attn = patch(self_attn, extra_args=['context'])
with torch.inference_mode():
patched_output = patched_self_attn.patched_forward(*attn_args,
**attn_kwargs,
context=context)
torch.testing.assert_close(patched_output[0],
continuous_tensor(attn_output[0], seq_length))
可以通过上述方法调试重写模块,直到精度满足预期。
@dataclass
class StepContext:
"""context of Model.
"""
inputs: ModelInputs
block_offsets: torch.LongTensor
position_ids: torch.LongTensor
position_ids_1d: torch.LongTensor
q_start_loc: torch.LongTensor
history_lengths: torch.LongTensor
seq_length: torch.LongTensor
max_seq_length: int
kv_seq_length: torch.LongTensor
kv_caches: List
is_decoding: bool
world_size: int = 1
json_config: Dict = None
local_adapter_ids: torch.LongTensor = None
global_adapter_ids: torch.LongTensor = None
adapter_offsets: torch.LongTensor = None
max_rank: int = 0
- 如何访问 patch 前的模块?
有时我们只希望在函数前后加一个 hook 代码,不希望大段的拷贝函数,可以通过 self.origin_mod
访问 patch 前的模块。
- 非 transformers 官方的模型该如何注册?
一些模型的实现代码可能是以 remote code 的形式添加的,这样的模块无法通过完整的 qualname 来定位。lmdeploy.pytorch 支持使用缩写的模块名进行注册:
MODULE_MAP.update({
'modeling_internlm.InternLMAttention':
'lmdeploy.pytorch.models.internlm.PatchedInternLMAttention',
})
Note
缩写的优先级会更低,有条件的话还是鼓励使用完整的 qualname 进行注册。
- 模块出现同名但不同实现怎么处理?
目前推荐的做法是同名就映射到同一个实现中,然后在实现内部根据模块的固有参数来判断模型该使用的类型,以 baichuan2 7b/13b 为例:
class BaichuanModel(nn.Module):
def forward(self, ...):
if self.config.num_hidden_layers == 32:
return forward_7b(...)
else:
return forward_default(...)
- 如果希望在推理前对模块进行初始化?
可以实现模块的 _update_model_fn
函数,它会在模块的权重都加载完,完成 TP 权重切分后被调用
class LlamaAttention:
def _update_model_fn(self):
# ADD YOUR CODE HERE