Zero Inference利用ZeRO stage 3的数据并行特性,能够将模型分布到多张GPU上,或者Offload到内存或者NVMe上,推理单GPU无法加载的模型
Zero Inference是数据并行的推理,因此需要在各个GPU同时启动推理进程并进行model.forward
,否则会卡住
- 张量并行(对模型手动切分)
- 通信量:
$O(batch \times len \times layer \times hidden)$
- 计算和通讯不能同时进行
- 通信量:
- 流水线并行(
AutoModelForCausalLM.from_pretrained(device_map="auto")
)- 通信量:
$O(batch \times len \times layer \times hidden)$
- 计算和通讯不能同时进行
- 通信量:
- 数据并行(ZeRO Inference)
- 通信量:
$O(hidden \times hidden)$
- 计算和通讯可以同时进行
- 通信量:
本仓库实现了LLaMA和LLaMA 2的flash attention,可在推理时启动,降低显存占用。 由于flash attention不支持自定义的attention mask,启动flash attention时,batch size必须设为1并关闭任何padding。
{"text": xxx}
{"text": xxx}
bash scripts/run_zero_inference.sh
可传入的参数有max_new_tokens
、min_new_tokens
、do_sample
、num_beams
、temperature
、top_k
、top_p
、repetition_penalty
。
具体说明见huggingface文档
bash scripts/run_zero_inference_backend_without_trainer.sh
devices
:指令使用哪几个显卡,格式同CUDA_VISIBLE_DEVICES
base_port
:后端服务监听端口,打开[base_port
,base_port
+num_devices
- 1]的端口
运行src/evaluation.ipynb
,由于ZeRO Inference要求多个model.forward
必须同时运行,必须设置synced_worker=True
,同时保证客户端连接上了每个后端进程