-
Notifications
You must be signed in to change notification settings - Fork 5.6k
paddle进程间tensor传输设计文档 paddle.multiprocessing
- 此为paddle.multiprocessing设计方案,相关实现已经由此PR实现:#37302
- 剩余未完成的功能较多,因此开放此文档,提供给外部开发者使用,希望外部开发者贡献完善。
- 此文档为2023飞桨黑客马拉松活动使用。
本文档,主要设计新增paddle.multiprocessing
模块。通过自定义Tensor序列化、反序列化方式,使用共享内存、cudaIpc等技术,实现paddle Tensor在进程间快速传输、共享。
实现 paddle.multiprocessing
模块,可在多进程间,方便快捷的传输Tensor。
-
功能支持: 动态图下,支持
paddle.Tensor, paddle.ParamBase
类型Tensor,在进程间传输,传输后Tensor属性与原来一致。(此处为基于LoDTensor的Tensor类型)支持CPU/GPU。 - 全局引用计数: 对共享的Tensor,支持全局引用计数。
- 多平台支持: 支持windows/mac平台,支持ROCM设备
-
共享Tensor生命周期管理 :
multiprocessing
进程间Tensor传输,是典型的生产者消费者场景。需要跨进程实现tensor的引用计数,确保tensor正确析构。
- 速度、功能符合预期: 在linux平台上,支持CPU/GPU传输Tensor,传输后的功能正常,可正常计算、修改。
- 不发生显存泄露: 支持初步的共享Tensor生命周期管理,尽量不发生显存,内存泄露。
竞品pytorch对进程间Tensor传输的支持,实现较早,目前整体的支持较为完备。 对于cpu Tensor进程间传输,支持linux、windows、mac平台。gpu Tensor支持cuda、rocm设备。具体支持情况的见下表。
torch.multiprocessing 设备支持列表
设备 | 平台 | 方案 | 补充 | |
---|---|---|---|---|
Host/CPU | linux | file_descriptor 、 file_system | file_descriptor 内存泄露风险小。打开文件获得句柄后,立即删除文件。多进程传输句柄。句柄关闭后文件系统释放存储。 | |
Host/CPU | win32 | file_system | file_system 存储为文件形式。多进程传输文件名。生命周期结束后需要删除文件。 | |
Host/CPU | mac | file_system | 同上 | |
Device/GPU | CUDA(linux) | share_cuda | 使用cudaIpcMemHandle 传输 |
|
Device/GPU | ROCM(linux) | share_cuda | 使用hipIpcMemHandle 传输 |
从代码结构上分析,下面是简要的代码模块介绍:
torch.multiprocessing 主要分为python层和C++层。
其中,python层主要有通过自定义ForkingPickler函数,改写Tensor相关类型的变量的序列化函数。在reduction.py
中定义了多种reduce、rebuild
方法。
python层API调用StorageSharing.cpp
中c++ api, 实现Tensor到序列化、反序列化。并且还支持了进程间Tensor引用计数(主要使用了RefcountedMapAllocator)。
此外cuda Tensor的引用计数支持较为复杂,这里实现了单独的CudaIPCTypes
来支持。主要原理是将引用计数的变量,写入到共享内存中,通过将引用计数的变量与Cuda IPC发送的变量绑定的方式(c10::detail::UniqueVoidPtr
),实现了全局引用计数。
主要增加Tensor存储这一部分的。
Tensor方面:当前计划只支持 LoDTensor,能够满足绝大部分动态图场景使用需求。 此外,对于共享的Tensor, 修改了部分 allocator 代码,以存储共享的Tensor变量。
设备支持方面:目前支持CPUPlace, CUDAPlace
, 支持CUDAPinnedPlace
(注:共享后应变为CPUPlace,PinnedPlace与share的状态不共存)。
cpu部分在mmap_allocator
,新增了RefCountMemoryMapAllocation
, 支持共享显存分配,读取,修改,写入,全局计数。gpu部分实现了单独的cuda_ipc_allocator
模块,可以将传输后的cuda ipc handle转化为Tensor holder的allocation。
注:
- 这里的
mmap_allocator
,cuda_ipc_allocator
目前主要在pybind的对外接口暴露的Tensor序列化/反序列化接口中使用。没有在其他地方使用。
整体设计主要参考pytorch的设计架构。我们在python层和c++层新增了部分序列化代码。paddle multiprocessing代码架构示意图如下:
python层主要新增了paddle.multiprocessing
文件夹,主要做了两件事情:
- 对
原始的multiprocessing做了封装
,import原始multiprocessing的所有方法。 - 添加了针对Paddle Tensor类型数据的
自定义进程间pickle协议函数
。
下图为整体Paddle Tensor进程间序列化的全部流程图。Tensor(VarBase) 先reduce为LoDTensor,然后LoDTensor在根据Tensor所在的设备(cpu/gpu)调用相应的c++函数,reduce成为handle文件。将handle传递给其他进程后,其他进程相应的反序列化函数即可重新构造为Tensor。
C++ 层主要分为pybind.cc中间的接口函数、进程间 Tensor 存储 allocation代码两部分。
pybind.cc主要定义的参数,返回值如下:
def("_share_cuda", [](LoDTensor self) {
// ...
return py::make_tuple(_handle,
(py::size_t) offset_bytes,
data_size, type_idx,
vectorize(self.dims()),
self.lod(), device_id);
}
.def("_new_shared_cuda", [](py::tuple t) {
LoDTensor tensor;
// ... 返回解析后的Tensor
return tensor;
}
.def("_share_filename", [](LoDTensor &self) {
// ...
return py::make_tuple(mmap_writer_allocation->ipc_name(),
mmap_writer_allocation->size(),
type_idx, vectorize(self.dims()), self.lod());
}
.def("_new_shared_filename", [](py::tuple t) {
LoDTensor tensor;
// ... 返回解析后的Tensor
return tensor;
}
反序列化函数,新建了一个空的LoDTensor, 将handle对应的数据取出来,变成 Allocation 类型,并且将空LoDTensor的 数据重置为 Allocation。 最后设置其他 Tensor 属性,返回给用户即可。
cpu部分在mmap_allocator.h
中,新增了MemoryMapAllocation
, RefCountMemoryMapAllocation
, 支持共享显存分配,读取,修改,写入,全局计数。
gpu部分实现了单独的cuda_ipc_allocator
模块,可以将传输后的cuda ipc handle转化为Tensor holder的CudaIpcAllocation
。同时提供了GetIpcBasePtr
接口,支持ipc handle到数据指针的转换。
python实现方案基本参考竞品设计思路,与torch基本一致。 c++层方面:
- cpu Tensor传输支持了
file_system
方案。与paddle之前的实现设计保持一致。 - gpu Tensor传输支持
cudaIpcMemHandle
方案。与竞品一致。
根据前面对于竞品的介绍,cpu Tensor进程间共享,竞品主要提供了两种方案: 方案一: file_system
// linux
int fd = shm_open(ipc_name.c_str(), flags, 0600);
// 进程间传递 ipc_name,子进程打开文件
优点:
- 较为简单,进程间传递文件名即可( 此处的 ipc name)。
缺点:
- Tensor以文件形式留存在文件系统内,进程异常退出后,可能发生文件残留。
方案二: file_descriptor
// open filename first
if((fd = shm_open(filename_.c_str(), flags, (mode_t)0600)) == -1) {
TORCH_CHECK(false, "error");
}
// 保存此处的 fd,用于进程间传输。
// delete file
if (shm_unlink(filename_.c_str()) == -1) {
TORCH_CHECK(false, "error");
}
优点:
- 传递文件描述符,比较小概率产生内存泄露问题。
缺点:
- 只能支持linux操作系统。
建议:目前paddle侧已有的初步开发与file_system 方案类似,将此方案改进为支持paddle.multiprocessing
的成本比较低。后期,可以进一步支持file_descriptor方案。
方案: CUDA Ipc 相关 API 可以支持进程间显存共享。 目前只有此方法支持。
使用API cudaIpcGetMemHandle 可以获取cudaIpcMemHandle_t
, 将cudaIpcMemHandle_t
转化为string,传递给子进程。子进程获取handle之后,使用 cudaIpcOpenMemHandle API 打开,获取对应的显存地址,即可访问父进程的显存地址空间,实现进程间显存共享。
// 获取进程间显存的 handle
__host__cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr )
// -> 进程间传输handle
// 使用handle打开显存地址
__host__cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags )
注:
- IPC 功能仅限于,在Linux 操作系统上支持统一寻址的设备。
- cudaIpcGetMemHandle 功能需要使用CudaMalloc的原始地址。
生产者、消费者场景,生命周期管理比较复杂。
方案一:用户使用event变量自己维护,生产者、消费者状态
## producer
# send Tensors, do something
send(Tensors)
event.wait()
event.clear()
## consumer
# receive Tensors and use them
Tensors = receive()
use_Tensors()
event.set()
del Tensors
建议使用 multiprocessing.Event()
进行进程间传输握手。
-
对于cpu Tensor生命周期:
- 目前已经初步实现了全局计数方案。但通过 event 变量传输握手,更安全,不容易发生内存泄露。
-
对于gpu Tensor生命周期:当前阶段,用户必须自己保证:生产者消费者状态。
- 生产者:用户需要保证消费者消费前,Tensor不被析构。
- 消费者,析构时候,close IpcMemoryHandle 即可,不会造成文件残留。
总体来看,由于gpu ipc传递方式不会显式存在文件残留,只需要用户自己维护生产者消费者状态。传输完成后,原始Tensor按照生产者的引用计数管理即可。造成显存泄露的风险较小。
方案二: 全局计数方案:实现共享Tensor全局引用计数。用户不需要显式进行,进程间同步
## producer
# send Tensors,
send(Tensors)
...
## consumer
x = queue.get()
# do somethings with x
del x
在共享内存总开辟额外空间,实现全局引用计数。 消费者接收后,增加了Tensor的引用计数,生产者引用计数减少为0后,可删除文件。 实现方案:基本实现原理参考竞品,开辟共享内存中的额外空间存放计数信息。
- 对于cpu Tensor
-
【已实现】 实现
RefCountMemoryMapAllocation
,在分配给Tensor分配存储数据空间时,额外增加一些字节,存储引用计数。所有IPC tensor,在建立时,引用计数+1,析构后-1,最后引用计数为0,删除共享内存中的文件。
-
【已实现】 实现
- 对于gpu Tensor
- 参考竞品,计划实现
allocation/utils/cuda_ipc_helper.h
,支持CudaIPCSentData
,CudaIPCRefcountedFiles
等功能,将ipc 传输后的Tensor与CudaIPCSentData
使用UniqueVoidPtr
绑定。全局引用计数。
- 参考竞品,计划实现
总的来讲,event消息维护方案对用户编写代码有一定要求,全局计数方案更易用。
建议:可以先合入初步版本,支持方案一使用(支持 cpu/gpu tensor),后续完善方案二(暂时支持 cpu tensor),完善全局引用计数支持。
需要自测的项目有:
- 不同设备,cpu、gpu传输
- Tensor异步进程读写检验
- 不同Tensor,Paddle.Tensor, Paddle.ParamBase
- Tensor属性,stop gradient,lod
- 内存泄露检测(多次传输)
paddle.multiprocessing
模块的主要场景有,数据传输、参数共享两种。
主要应用点有:
- 传输:DataLoader:多进程造数据,将Tensor传输到主进程。(N->1)造数据慢,训练快。
- 共享:异步多进程预测、训练:主进程传输后,主进程和子进程各自独立进行预测、训练。
import paddle
import paddle.multiprocessing as mp
paddle.set_device("cpu")
def fill_Tensor(queue, event):
data = queue.get()
data[:] = 5
event.set()
Tensor = paddle.zeros([5, 5], dtype="float32")
queue = mp.Queue()
event = mp.Event()
queue.put(Tensor)
process = ctx.Process(target=fill_Tensor, args=(queue, event))
process.daemon = True
process.start()
event.wait(30)
print(Tensor)
# Tensor(shape=[5, 5], dtype=float32, place=CPUPlace, stop_gradient=True,
# [[5., 5., 5., 5., 5.],
# [5., 5., 5., 5., 5.],
# [5., 5., 5., 5., 5.],
# [5., 5., 5., 5., 5.],
# [5., 5., 5., 5., 5.]])
process.join(4)
多个进程使用了相同地址的参数
import paddle.multiprocessing as mp
paddle.set_device("cpu")
from model import MyModel
def evaluate(model):
# Construct data_loader, etc.
model.eval()
for data, labels in data_loader:
predict = loss_fn(model(data))
calculate_acc(predict, labels)
if __name__ == '__main__':
num_processes = 4
model = MyModel()
model.share_memory() # 拷贝到共享内存
processes = []
for rank in range(num_processes):
p = mp.Process(target=evaluate, args=(model,))
p.start()
processes.append(p)
for p in processes:
p.join()