diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py index 7ae70d46f..ddf2478df 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -7,6 +7,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node +from lightllm.distributed.pynccl import PyNcclCommunicator logger = init_logger(__name__) @@ -35,7 +36,11 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def send_to_decode_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int + self, + move_tasks: List[KVMoveTask], + mem_managers: List["Deepseek2MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -49,7 +54,7 @@ def send_to_decode_node( cur_mem = mem_managers[cur_device_index] for layer_index in range(cur_mem.layer_num): move_buffer = cur_mem._get_kv_move_data(move_token_indexes, layer_index) - dist.send(move_buffer, dst=1) + nccl_comm.send(move_buffer, dst=1) return def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): @@ -61,7 +66,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): return move_buffer def receive_from_prefill_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -76,7 +85,7 @@ def receive_from_prefill_node( move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, self.head_num, self.head_dim) for layer_index in range(self.layer_num): - dist.recv(recive_buffer, src=0) + nccl_comm.recv(recive_buffer, src=0) for i, mem in enumerate(mem_managers): if i == cur_device_index: mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index) @@ -93,7 +102,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch. 
return def send_to_decode_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): """ 使用 p2p triton kernel 进行数据复制和传输的实现方式。 @@ -120,7 +133,7 @@ def send_to_decode_node_p2p( move_buffer = self._get_kv_move_data_p2p( move_token_indexes, token_dp_indexes, layer_index, self.kv_move_buffer, dp_size_in_node ) - dist.send(move_buffer, dst=1) + nccl_comm.send(move_buffer, dst=1) return def _get_kv_move_data_p2p( @@ -145,7 +158,11 @@ def _get_kv_move_data_p2p( return move_buffer def receive_from_prefill_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): if not hasattr(self, "mem_ptrs_dict"): self.mem_ptrs_dict = {} @@ -170,7 +187,7 @@ def receive_from_prefill_node_p2p( move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim) for layer_index in range(self.layer_num): - dist.recv(recive_buffer, src=0) + nccl_comm.recv(recive_buffer, src=0) self._write_kv_move_data_p2p( move_token_indexes, token_dp_indexes, recive_buffer, layer_index, dp_size_in_node ) diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 5e701effa..aae7112ff 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -10,6 +10,7 @@ from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.utils.dist_utils import get_current_rank_in_node from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.distributed.pynccl import PyNcclCommunicator logger = init_logger(__name__) @@ -91,7 +92,11 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def send_to_decode_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -108,14 +113,14 @@ def send_to_decode_node( for layer_index in range(mem.layer_num): move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index) if i == cur_device_index: - dist.send(move_buffer, dst=1) + nccl_comm.send(move_buffer, dst=1) else: move_size = move_buffer.numel() new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape) from torch.cuda import comm comm.broadcast(move_buffer, out=[new_move_buffer]) - dist.send(new_move_buffer, dst=1) + nccl_comm.send(new_move_buffer, dst=1) return def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): @@ -127,7 +132,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): return move_buffer def receive_from_prefill_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -144,7 +153,7 @@ def receive_from_prefill_node( recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim) for i, mem in enumerate(mem_managers): for 
layer_index in range(mem.layer_num): - dist.recv(recive_buffer, src=0) + nccl_comm.recv(recive_buffer, src=0) if i == cur_device_index: mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index) else: @@ -160,7 +169,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch. return def send_to_decode_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): """ 使用 p2p triton kernel 进行数据复制和传输的实现方式。 @@ -178,7 +191,7 @@ def send_to_decode_node_p2p( for i, mem in enumerate(mem_managers): for layer_index in range(mem.layer_num): move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer) - dist.send(move_buffer, dst=1) + nccl_comm.send(move_buffer, dst=1) return def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor): @@ -191,7 +204,11 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k return move_buffer def receive_from_prefill_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -209,7 +226,7 @@ def receive_from_prefill_node_p2p( recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim) for i, mem in enumerate(mem_managers): for layer_index in range(mem.layer_num): - dist.recv(recive_buffer, src=0) + nccl_comm.recv(recive_buffer, src=0) mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index) return diff --git a/lightllm/distributed/pynccl.py b/lightllm/distributed/pynccl.py new file mode 100644 index 000000000..b96e0d1ba --- /dev/null +++ b/lightllm/distributed/pynccl.py @@ -0,0 +1,285 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl.py +# of the vllm-project/vllm GitHub repository. +# +# Copyright 2023 ModelTC Team +# Copyright 2023 vLLM Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
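The mem-manager changes above only swap the transport layer: every per-layer staging loop that used to call dist.send / dist.recv on the global process group now drives the per-connection PyNcclCommunicator introduced below. The sketch that follows is illustrative only (the function names and the receive-buffer argument are assumptions, not code from this diff); it shows how the two ends of one connection are expected to pair up, with the prefill side acting as rank 0 and the decode side as rank 1, both iterating the layers in the same order.

import torch

def prefill_send_side(comm, mem_manager, move_token_indexes):
    # rank 0: stage each layer's KV slice into a contiguous buffer, then send it
    for layer_index in range(mem_manager.layer_num):
        move_buffer = mem_manager._get_kv_move_data(move_token_indexes, layer_index)
        comm.send(move_buffer, dst=1)

def decode_recv_side(comm, mem_manager, move_token_indexes, recv_buffer: torch.Tensor):
    # rank 1: receive into a reusable staging buffer, then scatter it into the local KV cache
    for layer_index in range(mem_manager.layer_num):
        comm.recv(recv_buffer, src=0)
        mem_manager._write_kv_move_data(move_token_indexes, recv_buffer, layer_index)

Because ncclSend/ncclRecv calls are matched purely by issue order on the two ranks, any divergence in the layer iteration between the sides would stall the transfer, which is why both implementations keep the loops symmetric.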
+ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from datetime import timedelta +import pickle +import time +from typing import Optional, Union, Dict, Deque, Tuple, Any +from collections import deque +import logging + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp, TCPStore + +from lightllm.distributed.pynccl_wrapper import ( + NCCLLibrary, + buffer_type, + cudaStream_t, + ncclComm_t, + ncclDataTypeEnum, + ncclRedOpTypeEnum, + ncclUniqueId, +) + +logger = logging.getLogger(__name__) + +_current_stream = None + + +def current_stream() -> torch.cuda.Stream: + global _current_stream + if _current_stream is None: + _current_stream = torch.cuda.current_stream() + return _current_stream + + +@dataclasses.dataclass +class StatelessP2PProcessGroup: + """A dataclass to hold a metadata store, and the rank, world_size of the + group. Only use it to communicate metadata between processes. + For data-plane communication, create NCCL-related objects. + """ + + dest_id: int + src_id: int + is_server: bool + + rank: int = 0 + world_size: int = 2 + store: TCPStore = None + data_expiration_seconds: int = 3600 # 1 hour + # dst rank -> counter + send_dst_counter: int = 0 + # src rank -> counter + recv_src_counter: int = 0 + entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque) + + def __post_init__(self): + self.rank = 0 if self.is_server else 1 + self.world_size = 2 + self.send_dst_counter = 0 + self.recv_src_counter = 0 + + def send_obj(self, obj: Any): + """Send an object to a destination rank.""" + self.expire_data() + key = f"send_to/{self.dest_id}/{self.send_dst_counter}" + self.store.set(key, pickle.dumps(obj)) + self.send_dst_counter += 1 + self.entries.append((key, time.time())) + + def expire_data(self): + """Expire data that is older than `data_expiration_seconds` seconds.""" + while self.entries: + # check the oldest entry + key, timestamp = self.entries[0] + if time.time() - timestamp > self.data_expiration_seconds: + self.store.delete_key(key) + self.entries.popleft() + else: + break + + def recv_obj(self) -> Any: + """Receive an object from a source rank.""" + obj = pickle.loads(self.store.get(f"send_to/{self.dest_id}/{self.recv_src_counter}")) + self.recv_src_counter += 1 + return obj + + @staticmethod + def create( + src_id: int, dest_id: int, is_server: bool, store: torch._C._distributed_c10d.Store + ) -> "StatelessP2PProcessGroup": + """A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. + + If we have process A and process B called `torch.distributed.init_process_group` + to form a group, and then we want to form another group with process A, B, C, + D, it is not possible in PyTorch, because process A and process B have already + formed a group, and process C and process D cannot join that group. This + function is a workaround for this issue. + + `torch.distributed.init_process_group` is a global call, while this function + is a stateless call. It will return a `StatelessProcessGroup` object that can be + used for exchanging metadata. With this function, process A and process B + can call `StatelessProcessGroup.create` to form a group, and then process A, B, + C, and D can call `StatelessProcessGroup.create` to form another group. 
+ """ # noqa + return StatelessP2PProcessGroup(src_id=src_id, dest_id=dest_id, is_server=is_server, store=store) + + +class PyNcclCommunicator: + def __init__( + self, + group: Union[ProcessGroup, StatelessP2PProcessGroup], + device: Union[int, str, torch.device], + library_path: Optional[str] = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + if not isinstance(group, StatelessP2PProcessGroup): + assert dist.is_initialized() + assert ( + dist.get_backend(group) != dist.Backend.NCCL + ), "PyNcclCommunicator should be attached to a non-NCCL group." + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + else: + self.rank = group.rank + self.world_size = group.world_size + + self.group = group + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. in a non-GPU environment + self.available = False + self.disabled = True + return + + self.available = True + self.disabled = False + + logger.info("LightLLM is using nccl==%s", self.nccl.ncclGetVersion()) + + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + + if not isinstance(group, StatelessP2PProcessGroup): + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + else: + if group.rank == 0: + group.send_obj(self.unique_id) + else: + self.unique_id = group.recv_obj() + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank(self.world_size, self.unique_id, self.rank) + + stream = current_stream() + # A small all_reduce for warmup. 
+ data = torch.zeros(1, device=device) + self.all_reduce(data) + stream.synchronize() + del data + + def destroy(self): + self.nccl.ncclCommDestroy(self.comm) + + def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor: + if self.disabled: + return None + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert in_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {in_tensor.device}" + ) + + out_tensor = torch.empty_like(in_tensor) + + if stream is None: + stream = current_stream() + self.nccl.ncclAllReduce( + buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + return out_tensor + + def send(self, tensor: torch.Tensor, dst: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = current_stream() + self.nccl.ncclSend( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + dst, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def recv(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = current_stream() + self.nccl.ncclRecv( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) diff --git a/lightllm/distributed/pynccl_wrapper.py b/lightllm/distributed/pynccl_wrapper.py new file mode 100644 index 000000000..344689d96 --- /dev/null +++ b/lightllm/distributed/pynccl_wrapper.py @@ -0,0 +1,424 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl_wrapper.py +# of the vllm-project/vllm GitHub repository. +# +# Copyright 2023 ModelTC Team +# Copyright 2023 vLLM Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 + +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. 
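PyNcclCommunicator can sit on top of either a regular (non-NCCL) torch.distributed group or the StatelessP2PProcessGroup defined above. Below is a minimal two-process sketch of the stateless path; it assumes (this is not stated in the diff) that the prefill side hosts the TCPStore at pd_prefill_nccl_ip / pd_prefill_nccl_port and therefore becomes rank 0, and that src_id / dest_id are the two node ids passed identically on both ends so the unique-id handshake keys match.

import torch
from torch.distributed import TCPStore
from lightllm.distributed.pynccl import PyNcclCommunicator, StatelessP2PProcessGroup

def build_p2p_communicator(is_server: bool, nccl_ip: str, nccl_port: int,
                           src_id: int, dest_id: int, device_id: int) -> PyNcclCommunicator:
    # both processes point at the same TCPStore; only the server side hosts it
    store = TCPStore(host_name=nccl_ip, port=nccl_port, world_size=2, is_master=is_server)
    group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=is_server, store=store)
    # rank 0 generates the ncclUniqueId and publishes it through the store, rank 1 fetches it;
    # both then call ncclCommInitRank and run the warmup all_reduce inside __init__
    return PyNcclCommunicator(group, device=device_id)

# prefill process (rank 0):                        decode process (rank 1):
#   comm = build_p2p_communicator(True, ...)         comm = build_p2p_communicator(False, ...)
#   comm.send(kv_buffer, dst=1)                      comm.recv(kv_buffer, src=0)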
For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +import logging + +logger = logging.getLogger(__name__) + + +def find_nccl_library() -> str: + """ + We either use the library file specified by the `VLLM_NCCL_SO_PATH` + environment variable, or we find the library file brought by PyTorch. + After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be + found by `ctypes` automatically. + """ + so_file = None + + # manually load the nccl library + if so_file: + logger.info("Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file) + else: + if torch.version.cuda is not None: + so_file = "libnccl.so.2" + elif torch.version.hip is not None: + so_file = "librccl.so.1" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.info("Found nccl from library %s", so_file) + return so_file + + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + 
argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function( + "ncclCommInitRank", ncclResult_t, [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int] + ), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclAllReduce", + ncclResult_t, + [buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t], + ), + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclAllGather", + ncclResult_t, + [buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclComm_t, cudaStream_t], + ), + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclReduceScatter", + ncclResult_t, + [buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t], + ), + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function( + "ncclSend", + ncclResult_t, + [buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t], + ), + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function( + "ncclRecv", + ncclResult_t, + [buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t], + ), + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function( + "ncclBroadcast", + ncclResult_t, + [buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t], + ), + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. 
+ # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s. " + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s. " + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", + so_file, + platform.platform(), + ) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: Dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), world_size, unique_id, rank)) + return comm + + def ncclAllReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream)) + + def ncclReduceScatter( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + 
self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, count, datatype, op, comm, stream)) + + def ncclAllGather( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, datatype, comm, stream)) + + def ncclSend( + self, sendbuff: buffer_type, count: int, datatype: int, dest: int, comm: ncclComm_t, stream: cudaStream_t + ) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream)) + + def ncclRecv( + self, recvbuff: buffer_type, count: int, datatype: int, src: int, comm: ncclComm_t, stream: cudaStream_t + ) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)) + + def ncclBroadcast( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, datatype, root, comm, stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", + "ncclDataTypeEnum", + "ncclRedOpTypeEnum", + "ncclUniqueId", + "ncclComm_t", + "cudaStream_t", + "buffer_type", +] + + +def test_ncclGetUniqueId(): + lib = NCCLLibrary() + unique_id = lib.ncclGetUniqueId() + print(unique_id.internal) + # `list(unique_id.internal)` is something like this: + # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # as long as the function doesn't raise an exception, we're good + assert unique_id is not None + + +if __name__ == "__main__": + import torch + + torch.cuda.set_device(0) + test_ncclGetUniqueId() diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index d5d22c8ea..22405867c 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -75,6 +75,28 @@ class DecodeNodeInfo: max_new_tokens: int +@dataclass +class PDTransJoinInfo: + decode_id: int + decode_device_id: int + prefill_id: int + prefill_device_id: int + pd_prefill_nccl_ip: str + pd_prefill_nccl_port: int + # 用于标识一次唯一的连接,prefill_id 和 decode_id 相同时,可能因为网络原因重连,为了更好的区分 + # 一次连接,使用一个 uuid 为其标识 + connect_id: str + + +@dataclass +class PDTransLeaveInfo: + decode_id: int + prefill_id: int + # 用于标识一次唯一的连接,prefill_id 和 decode_id 相同时,可能因为网络原因重连,为了更好的区分 + # 一次连接,使用一个 uuid 为其标识 + connect_id: str + + @dataclass class KVMoveTask: group_request_id: int @@ -90,6 +112,8 @@ class KVMoveTask: prefill_dp_index: int decode_dp_index: int mark_start_time: float = None + # 标记任务使用某个连接id进行传输 + connect_id: str = None def __post_init__(self): if len(self.input_tokens) <= 0: @@ -102,14 +126,14 @@ def to_prefill_log_info(self): d_i = self.prefill_dp_index id = self.group_request_id log = f"id: {id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len} 
dp_index:{d_i}" - return log + return log + f" connect_id: {self.connect_id}" def to_decode_log_info(self): v_len = None if self.decode_token_indexes is None else len(self.decode_token_indexes) d_i = self.decode_dp_index id = self.group_request_id log = f"id: {id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len} dp_index:{d_i}" - return log + return log + f" connect_id: {self.connect_id}" def id(self): return self.group_request_id @@ -119,3 +143,9 @@ def get_cost_time(self): return time.time() - self.mark_start_time else: return 100000000000 + + +@dataclass +class KVMoveTaskGroup: + tasks: List[KVMoveTask] + connect_id: str diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 51411c301..5424c12da 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -98,6 +98,7 @@ def __init__(self, args, router_port, detokenization_port, metric_port): self.stats_tool = Stats(not args.disable_log_stats, args.log_stats_interval) self.metric_client = MetricClient(metric_port) self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode"] + self.is_pd_decode_mode = self.args.run_mode == "decode" # p d 分离模式下,需要调度锁来同步调度端和推理端的一些数据操作 # 主要是为了防止调度失误,造成 OOM 等错误 self.router_lock = mp.Lock() @@ -240,14 +241,17 @@ async def loop_for_fwd( ) / self.max_total_token_num d_i = dp_index frozen_token_num = self.shared_token_load.get_frozened_token_count(d_i) + estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i) logger.debug( f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n" f"dp_i {d_i} paused req num: {self.req_queue.get_paused_req_num()} \n" f"dp_i {d_i} frozen token num: {frozen_token_num} \n" + f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n" f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n" f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token" ) - self.req_queue.update_token_load(self.running_batch, force_update=False) + # pd decode mode need to update token_load more frequently + self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode) self.stats_tool.print_stats() self.metric_client.gauge_set("lightllm_batch_current_size", len(self.running_batch.reqs)) self.metric_client.gauge_set("lightllm_batch_pause_size", self.req_queue.get_paused_req_num()) @@ -270,7 +274,9 @@ async def loop_for_fwd( if log_time_ready("frozen_info", 60): for dp_i in range(self.dp_size_in_node): frozen_token_num = self.shared_token_load.get_frozened_token_count(dp_i) + estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(dp_i) logger.debug(f"dp_i {dp_i} frozen token num: {frozen_token_num} \n") + logger.debug(f"dp_i {dp_i} estimated_peak_token_count: {estimated_peak_token_count} \n") if self.running_batch is None: await asyncio.sleep(0.01) # 10ms diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py index 0b161f5de..8f88237ec 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py @@ -71,6 +71,20 @@ def recover_frozen_token(self, key_len, max_new_token): def 
_alloc_to_frozen_some_tokens(self, move_task: KVMoveTask): is_ok = self.judge_token_is_ok(len(move_task.input_tokens), move_task.decode_node.max_new_tokens) if not is_ok: + if self.is_master_in_dp: + logger.info(f"req_id: {move_task.to_decode_log_info()} alloc token failed") + shared_token_load = self.backend.shared_token_load + dp_rank = self.dp_rank_in_node + frozen_token_num = shared_token_load.get_frozened_token_count(dp_rank) + estimated_peak_token_num = shared_token_load.get_estimated_peak_token_count(dp_rank) + logger.debug( + f"radix refed token num {self.backend.radix_cache.get_refed_tokens_num()}\n" + f"radix hold token num {self.backend.radix_cache.get_tree_total_tokens_num()}\n" + f"mem manager can alloc token num {self.backend.model.mem_manager.can_use_mem_size}\n" + f"mem manager total size {self.backend.model.mem_manager.size}" + f"frozened token num {frozen_token_num}\n" + f"estimated peak token num {estimated_peak_token_num}\n" + ) return None key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu") diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py index 30096e3e5..457fd1b9c 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py @@ -16,7 +16,7 @@ from .decode_infer_rpyc import PDDecodeInferRpcServer from ..task_queue import TaskQueue import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import KVMoveTask, UpKVStatus +from lightllm.server.pd_io_struct import KVMoveTask, UpKVStatus, PDTransJoinInfo, PDTransLeaveInfo from lightllm.utils.retry_utils import retry import numpy as np from rpyc import AsyncResult @@ -31,221 +31,6 @@ KV_MOVE_MAX_NUM = 16 -@dataclass -class TransProcessObj: - prefill_node_id: str = None - process: mp.Process = None - task_in_queue: mp.Queue = None - task_out_queue: mp.Queue = None - nccl_ip: str = None - nccl_port: str = None - device_index: int = None - manager: "DecodeKVMoveManager" = None - has_error: bool = False - ready_to_move_queue: TaskQueue = None - kv_move_thread: threading.Thread = None - move_finished_queue: TaskQueue = None - put_to_radix_thread: threading.Thread = None - latest_check_time: float = None - - def create(self, prefill_node_id: str, nccl_ip: str, nccl_port: int, manager: "DecodeKVMoveManager"): - from .decode_trans_process import start_decode_trans_process - - task_in_queue = mp.Queue() - task_out_queue = mp.Queue() - device_index = manager.get_next_device_index() - proc = start_decode_trans_process( - manager.args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, manager.mem_queues - ) - assert task_out_queue.get(timeout=30) == "proc_start" - manager._put_mem_manager_to_mem_queue() - assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" - assert task_out_queue.get(timeout=60) == "nccl_ok" - - self.prefill_node_id = prefill_node_id - self.process = proc - self.task_in_queue = task_in_queue - self.task_out_queue = task_out_queue - self.nccl_ip = nccl_ip - self.nccl_port = nccl_port - self.device_index = device_index - - self.manager = manager - self.latest_check_time = time.time() - - self.ready_to_move_queue = TaskQueue( - get_func=lambda datas: datas[0:1], 
fail_func=self.manager.put_to_fail_release_task_queue - ) - self.kv_move_thread = threading.Thread(target=self.kv_move_loop, daemon=True) - self.kv_move_thread.start() - - self.move_finished_queue = TaskQueue( - get_func=lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=self.manager.put_to_fail_release_task_queue - ) - self.put_to_radix_thread = threading.Thread(target=self.put_to_radix_loop, daemon=True) - self.put_to_radix_thread.start() - return - - def check_trans_process(self, raise_exception=True): - process = psutil.Process(self.process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - self.set_has_error() - if raise_exception: - raise Exception(f"trans process: {self.process.pid} is dead") - return - - def timer_to_check_status(self, raise_exception=True): - if time.time() - self.latest_check_time >= 2.0: - self.latest_check_time = time.time() - self.check_trans_process(raise_exception=raise_exception) - return - - def _transfer_kv(self, move_tasks: List[KVMoveTask]): - with self.manager.device_locks[self.device_index]: - self.task_in_queue.put(move_tasks.copy(), timeout=10) - assert self.task_out_queue.get(timeout=60) == "ok" - logger.info(f"_transfer_kv ok {move_tasks[0].to_decode_log_info()}") - - # 标记 decode 接收到 kv cache 的时间 - for move_task in move_tasks: - move_task.mark_start_time = time.time() - - self.move_finished_queue.put_list(move_tasks) - move_tasks.clear() - - def kv_move_loop(self): - func_name = self.kv_move_loop.__name__ - while not self.has_error: - move_tasks: List[List[KVMoveTask]] = self.ready_to_move_queue.get_tasks(log_tag="ready_to_move_queue") - if len(move_tasks) == 0: - time.sleep(0.01) - continue - - if len(move_tasks) != 1: - logger.error(f"error get need 1, but get {len(move_tasks)}") - assert False - - move_tasks = move_tasks[0] - for task in move_tasks: - logger.info(f"{func_name} get task {task.to_decode_log_info()}") - - try: - self.timer_to_check_status(raise_exception=True) - - if not kv_trans_use_p2p(): - with self.manager.kv_trans_lock: - self._transfer_kv(move_tasks) - else: - self._transfer_kv(move_tasks) - - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.ready_to_move_queue.clear_tasks() - self.manager.remove_trans_obj(self.prefill_node_id) - - finally: - self.manager.put_to_fail_release_task_queue(move_tasks) - - logger.error(f"{func_name} prefill id {self.prefill_node_id} device_index {self.device_index} thread quit") - return - - def put_to_radix_loop(self): - func_name = self.put_to_radix_loop.__name__ - while not self.has_error: - move_tasks: List[KVMoveTask] = self.move_finished_queue.get_tasks(log_tag="move_finished_queue") - if len(move_tasks) == 0: - time.sleep(0.01) - continue - - for task in move_tasks: - logger.info(f"{func_name} get put radix task {task.to_decode_log_info()}") - - try: - # random to check stats - self.timer_to_check_status(raise_exception=True) - - self.manager._put_kv_received_to_radix_cache(move_tasks.copy()) - for task in move_tasks.copy(): - logger.info( - f"{func_name} put kv to radix cache ok, req_id: {task.id()} cost_time {task.get_cost_time()} s" - ) - self.manager.up_status_in_queue.put( - UpKVStatus(group_request_id=task.group_request_id, dp_index=task.decode_dp_index) - ) - logger.info(f"{func_name} up kv status req_id: {task.id()} finished") - move_tasks.clear() - - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.move_finished_queue.clear_tasks() - 
self.manager.remove_trans_obj(self.prefill_node_id) - - finally: - self.manager.put_to_fail_release_task_queue(move_tasks) - - logger.error(f"{func_name}, prefill id {self.prefill_node_id} device_index {self.device_index} thread quit") - return - - def wait_thread_quit(self): - if self.kv_move_thread is not None: - if self.kv_move_thread.is_alive(): - try: - self.kv_move_thread.join() - except: - pass - if self.put_to_radix_thread is not None: - if self.put_to_radix_thread.is_alive(): - try: - self.put_to_radix_thread.join() - except: - pass - return - - def has_error_status(self): - try: - assert self.has_error is False - assert self.kv_move_thread.is_alive() - assert self.put_to_radix_thread.is_alive() - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - return True - - return False - - def set_has_error(self): - self.has_error = True - try: - self.ready_to_move_queue.has_error = True - self.move_finished_queue.has_error = True - except: - pass - return - - def __del__(self): - logger.error(f"trans obj del start, prefill node id {self.prefill_node_id} device_index {self.device_index}") - - try: - self.set_has_error() - self.wait_thread_quit() - if self.ready_to_move_queue is not None: - self.ready_to_move_queue.clear_tasks() - if self.move_finished_queue is not None: - self.move_finished_queue.clear_tasks() - except BaseException as e: - logger.exception(str(e)) - - logger.error(f"trans obj deled, prefill node id {self.prefill_node_id} device_index {self.device_index}") - - # 强制关闭连接和杀掉传输进程 - if self.process is not None: - logger.warning(f"trans kv process {self.process.pid} is killed") - os.kill(self.process.pid, signal.SIGKILL) - pass - - class DecodeKVMoveManager(rpyc.Service): def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): super().__init__() @@ -261,7 +46,10 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.mem_queues = mem_queues self.infer_rpyc_lock = threading.Lock() self.infer_rpyc_objs: List[PDDecodeInferRpcServer] = [] - self.node_id_to_trans_obj: Dict[str, TransProcessObj] = {} + + from .decode_trans_obj import KVTransConnectObj + + self.connect_id_to_trans_obj: Dict[str, KVTransConnectObj] = {} for port in self.args.pd_node_infer_rpyc_ports: socket_path = f"/tmp/{get_unique_server_name()}_decode_node_infer_rpyc_{port}" from rpyc.utils.factory import unix_connect @@ -281,29 +69,28 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.fail_to_release_thread = threading.Thread(target=self.handle_fail_release_task_loop, daemon=True) self.fail_to_release_thread.start() + # 在不使用p2p 复制kv 的方案时,需要全局的传输锁进行控制。这个时候kv传输的效率会下降。 self.kv_trans_lock = threading.Lock() - # 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。 - self.device_locks = [threading.Lock() for _ in range(self.node_world_size)] - return - def put_to_fail_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): - if isinstance(task, KVMoveTask): - self.fail_to_release_queue.put(task) - elif isinstance(task, list): - self.fail_to_release_queue.put_list(task) - else: - assert False, "error input" - return + from .decode_trans_obj import KVTransProcess + + self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size + for device_id in range(self.node_world_size): + self.kv_trans_processes[device_id] = KVTransProcess() + assert self.kv_trans_processes[device_id].init_all(device_id, self) - def handle_fail_release_task_loop(self): - while True: - handle_list: List[KVMoveTask] = 
self.fail_to_release_queue.get_tasks(log_tag="fail_to_release_queue") - if len(handle_list) == 0: - time.sleep(0.01) - else: - self._fail_to_realese_forzen_tokens(handle_list) return + # ================================================================================== + # _dp_alloc_to_frozen_some_tokens + # _put_kv_received_to_radix_cache + # _fail_to_realese_forzen_tokens + # _unfrozen_time_out_reqs_tokens + # _put_mem_manager_to_mem_queue + # 上述接口都是 kv move manager 与推理进程进行交互的接口,主要用于申请锁定kv资源或者释放 + # kv资源的接口 + # ================================================================================== + async def wait_all_future_finish(self, futures: List[AsyncResult]): await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures]) return @@ -373,16 +160,50 @@ def _put_mem_manager_to_mem_queue(self) -> None: obj.put_mem_manager_to_mem_queue() return + # ================================================================================== + # put_to_fail_release_task_queue 将因为一些原因失败,需要释放锁定的kv资源的请求放入到 + # 对应的处理队列中,handle_fail_release_task_loop 是一个循环的线程,专门处理这些失败的请求 + # 通过调用与推理进程交互的接口,释放掉申请锁定的 kv 资源。 + # ================================================================================== + + def put_to_fail_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): + if isinstance(task, KVMoveTask): + self.fail_to_release_queue.put(task) + elif isinstance(task, list): + self.fail_to_release_queue.put_list(task) + else: + assert False, "error input" + return + + def handle_fail_release_task_loop(self): + while True: + handle_list: List[KVMoveTask] = self.fail_to_release_queue.get_tasks(log_tag="fail_to_release_queue") + if len(handle_list) == 0: + time.sleep(0.01) + else: + self._fail_to_realese_forzen_tokens(handle_list) + return + + # ================================================================================== + # on_connect + # on_disconnect + # exposed_check_alive + # exposed_build_trans_process + # exposed_request_data_transfer + # 上述接口是decode kv move manager 暴露的 rpyc 调用接口,用于 prefill kv move manager + # 进行连接,进行一些元数据资源的交互。 + # ================================================================================== + def on_connect(self, conn): # 用于处理连接断开的时候,自动删除资源 - thread_local_data.prefill_node_id = None + thread_local_data.connect_id = None pass def on_disconnect(self, conn): # 用于处理连接断开的时候,自动删除资源 - if thread_local_data.prefill_node_id is not None: - self.remove_trans_obj(thread_local_data.prefill_node_id) - logger.info(f"prefill node id {thread_local_data.prefill_node_id} disconnect") + if thread_local_data.connect_id is not None: + self.remove_trans_obj(thread_local_data.connect_id) + logger.info(f"connect id {thread_local_data.connect_id} disconnect") import gc gc.collect() @@ -392,18 +213,22 @@ def exposed_check_alive(self): # 用于 prefill node check 通信连接的状态。 return - def exposed_build_trans_process(self, prefill_node_id, nccl_ip, nccl_port, prefill_node_max_kv_trans_num): - prefill_node_id, nccl_ip, nccl_port, prefill_node_max_kv_trans_num = list( - map(obtain, [prefill_node_id, nccl_ip, nccl_port, prefill_node_max_kv_trans_num]) + def exposed_build_trans_connect( + self, prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num, connect_id + ): + prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num = list( + map(obtain, [prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num]) ) - thread_local_data.prefill_node_id = prefill_node_id - - logger.info(f"build trans infos 
{prefill_node_id} {nccl_ip} {nccl_port}") - # 如果有历史残留,一并移除 - self.remove_trans_obj(prefill_node_id) - tran_obj = TransProcessObj() - tran_obj.create(prefill_node_id, nccl_ip, nccl_port, self) - self.node_id_to_trans_obj[prefill_node_id] = tran_obj + connect_id = obtain(connect_id) + thread_local_data.connect_id = connect_id + + logger.info(f"build trans infos {prefill_node_id} {pd_prefill_nccl_ip} {pd_prefill_nccl_port} {connect_id}") + + from .decode_trans_obj import KVTransConnectObj + + tran_obj = KVTransConnectObj() + tran_obj.create(connect_id, prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, self) + self.connect_id_to_trans_obj[connect_id] = tran_obj return min(prefill_node_max_kv_trans_num, self.args.max_total_token_num) # 返回 None 代表繁忙, 放弃该任务的 kv 传送 @@ -451,59 +276,90 @@ def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optiona except BaseException as e: self.put_to_fail_release_task_queue(alloc_tokened_tasks) alloc_tokened_tasks = [] - self.remove_trans_obj(tasks[0].prefill_node_id) + self.remove_trans_obj(tasks[0].connect_id) logger.exception(str(e)) raise e + if alloc_tokened_tasks: + trans_obj.ready_to_move_queue.put( + alloc_tokened_tasks, error_handle_func=self.put_to_fail_release_task_queue + ) + + return ans_list + + # ================================================================================== + # 定时检测kv 传输成功,但是长时间没有pd master来触发推理的请求, + # 释放这些超时请求占用的kv资源 + # ================================================================================== + + def timer_loop(self): try: - if len(alloc_tokened_tasks) != 0: - trans_obj.ready_to_move_queue.put(alloc_tokened_tasks) - except BaseException as e: + while True: + self._unfrozen_time_out_reqs_tokens() + time.sleep(3.5) + except (BaseException, RuntimeError) as e: logger.exception(str(e)) - self.put_to_fail_release_task_queue(alloc_tokened_tasks) - alloc_tokened_tasks = [] raise e - return ans_list + # ================================================================================== + # 定时检测传输进程的健康状态,出现问题拉崩整个系统触发重启 + # ================================================================================== + + def check_trans_process_loop(self): + try: + while True: + for device_id in range(self.node_world_size): + if not self.kv_trans_processes[device_id].is_trans_process_health(): + raise Exception(f"device_id {device_id} kv process is unhealth") + + time.sleep(10.0) + except (BaseException, RuntimeError) as e: + logger.exception(str(e)) + for device_id in range(self.node_world_size): + self.kv_trans_processes[device_id].killself() + + # 杀掉当前进程的父进程(router), 触发全局崩溃 + os.kill(os.getppid(), signal.SIGKILL) + os.kill(os.getpid(), signal.SIGKILL) + raise e + + # ================================================================================== + # 常用辅助功能函数 + # ================================================================================== def get_next_device_index(self): counts = [0 for _ in range(self.node_world_size)] - for obj in self.node_id_to_trans_obj.values(): + for obj in self.connect_id_to_trans_obj.values(): counts[obj.device_index] += 1 device_index = int(np.argmin(counts)) return device_index def get_trans_obj(self, task: KVMoveTask): - self.remove_dead_trans_obj() - return self.node_id_to_trans_obj[task.prefill_node_id] + self.__remove_dead_trans_obj() + return self.connect_id_to_trans_obj[task.connect_id] - def remove_dead_trans_obj(self): - del_node_ids = [] - for node_id, t_obj in self.node_id_to_trans_obj.items(): + def __remove_dead_trans_obj(self): + del_connect_ids = [] + for 
connect_id, t_obj in self.connect_id_to_trans_obj.items(): if t_obj.has_error_status(): - del_node_ids.append(node_id) + del_connect_ids.append(connect_id) - for node_id in del_node_ids: - self.node_id_to_trans_obj.pop(node_id, None) + for connect_id in del_connect_ids: + self.connect_id_to_trans_obj.pop(connect_id, None) - if len(del_node_ids) != 0: + if del_connect_ids: import gc gc.collect() return - def remove_trans_obj(self, prefill_node_id): - if prefill_node_id in self.node_id_to_trans_obj: - trans_obj = self.node_id_to_trans_obj.pop(prefill_node_id, None) + def remove_trans_obj(self, connect_id): + if connect_id in self.connect_id_to_trans_obj: + trans_obj = self.connect_id_to_trans_obj.pop(connect_id, None) if trans_obj is not None: trans_obj.set_has_error() return - def timer_loop(self): - while True: - self._unfrozen_time_out_reqs_tokens() - time.sleep(3.5) - def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event): import lightllm.utils.rpyc_fix_utils as _ @@ -515,6 +371,9 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp. t = ThreadedServer(manager, port=args.pd_decode_rpyc_port, protocol_config={"allow_pickle": True}) threading.Thread(target=lambda: t.start(), daemon=True).start() + kv_trans_process_check = threading.Thread(target=manager.check_trans_process_loop, daemon=True) + kv_trans_process_check.start() + event.set() manager.timer_loop() return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py new file mode 100644 index 000000000..fd42b3772 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py @@ -0,0 +1,303 @@ +import time +import psutil +import threading +from typing import List +from dataclasses import dataclass +from lightllm.utils.log_utils import init_logger +from ..task_queue import TaskQueue +import torch.multiprocessing as mp +from lightllm.server.pd_io_struct import KVMoveTask, UpKVStatus, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup +from lightllm.utils.device_utils import kv_trans_use_p2p +from .decode_kv_move_manager import DecodeKVMoveManager +from lightllm.utils.time_utils import TimeChecker +from ..utils import join_if_alive, clear_queue + +logger = init_logger(__name__) + +KV_MOVE_MAX_NUM = 16 + + +@dataclass +class KVTransConnectObj: + connect_id: str = None + prefill_node_id: int = None + kv_trans_process: "KVTransProcess" = None + pd_prefill_nccl_ip: str = None + pd_prefill_nccl_port: int = None + device_index: int = None + manager: "DecodeKVMoveManager" = None + has_error: bool = False + ready_to_move_queue: TaskQueue = None + kv_move_thread: threading.Thread = None + move_finished_queue: TaskQueue = None + put_to_radix_thread: threading.Thread = None + timer_checker: TimeChecker = None + + def create( + self, + connect_id: str, + prefill_node_id: str, + pd_prefill_nccl_ip: str, + pd_prefill_nccl_port: int, + manager: "DecodeKVMoveManager", + ): + self.connect_id = connect_id + self.device_index = manager.get_next_device_index() + self.kv_trans_process = manager.kv_trans_processes[self.device_index] + decode_node_id = manager.args.pd_node_id + self.prefill_node_id = prefill_node_id + self.decode_node_id = decode_node_id + self.pd_prefill_nccl_ip = pd_prefill_nccl_ip + self.pd_prefill_nccl_port = pd_prefill_nccl_port + + 
self.manager = manager + self.timer_checker = TimeChecker(6) + + with self.kv_trans_process.device_lock: + clear_queue(self.kv_trans_process.task_out_queue) + self.kv_trans_process.task_in_queue.put( + PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=-1, + pd_prefill_nccl_ip=pd_prefill_nccl_ip, + pd_prefill_nccl_port=pd_prefill_nccl_port, + decode_id=decode_node_id, + decode_device_id=self.device_index, + connect_id=self.connect_id, + ) + ) + assert self.kv_trans_process.task_out_queue.get(timeout=60) == "nccl_ok" + + self.ready_to_move_queue = TaskQueue( + get_func=lambda datas: datas[0:1], fail_func=self.manager.put_to_fail_release_task_queue + ) + self.kv_move_thread = threading.Thread(target=self.kv_move_loop, daemon=True) + self.kv_move_thread.start() + + self.move_finished_queue = TaskQueue( + get_func=lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=self.manager.put_to_fail_release_task_queue + ) + self.put_to_radix_thread = threading.Thread(target=self.put_to_radix_loop, daemon=True) + self.put_to_radix_thread.start() + return + + # ================================================================================== + # 处理接受所有进行 kv 传输的请求,完成后,将请求放入到 move_finished_queue 中 + # ================================================================================== + + def _transfer_kv(self, move_tasks: List[KVMoveTask]): + with self.kv_trans_process.device_lock: + clear_queue(self.kv_trans_process.task_out_queue) + kv_move_group = KVMoveTaskGroup(tasks=move_tasks.copy(), connect_id=self.connect_id) + kv_move_group.connect_id = self.connect_id + self.kv_trans_process.task_in_queue.put(kv_move_group, timeout=10) + assert self.kv_trans_process.task_out_queue.get(timeout=60) == "ok" + logger.info(f"_transfer_kv ok {move_tasks[0].to_decode_log_info()}") + + # 标记 decode 接收到 kv cache 的时间 + for move_task in move_tasks: + move_task.mark_start_time = time.time() + + self.move_finished_queue.put_list(move_tasks) + move_tasks.clear() + + def kv_move_loop(self): + func_name = self.kv_move_loop.__name__ + while not self.has_error: + move_tasks: List[List[KVMoveTask]] = self.ready_to_move_queue.get_tasks(log_tag="ready_to_move_queue") + if len(move_tasks) == 0: + time.sleep(0.01) + continue + + if len(move_tasks) != 1: + logger.error(f"error get need 1, but get {len(move_tasks)}") + assert False + + move_tasks: List[KVMoveTask] = move_tasks[0] + for task in move_tasks: + logger.info(f"{func_name} get task {task.to_decode_log_info()}") + + try: + self.timer_to_check_status(raise_exception=True) + if not kv_trans_use_p2p(): + with self.manager.kv_trans_lock: + self._transfer_kv(move_tasks) + else: + self._transfer_kv(move_tasks) + + except BaseException as e: + logger.exception(str(e)) + self.set_has_error() + self.ready_to_move_queue.clear_tasks() + + finally: + self.manager.put_to_fail_release_task_queue(move_tasks) + + logger.error(f"{func_name} thread quit") + return + + # ================================================================================== + # 将传输完成的请求,放入到 radix cache 中进行管理。 + # ================================================================================== + + def put_to_radix_loop(self): + func_name = self.put_to_radix_loop.__name__ + while not self.has_error: + move_tasks: List[KVMoveTask] = self.move_finished_queue.get_tasks(log_tag="move_finished_queue") + if len(move_tasks) == 0: + time.sleep(0.01) + continue + + for task in move_tasks: + logger.info(f"{func_name} get put radix task {task.to_decode_log_info()}") + + try: + 
self.timer_to_check_status(raise_exception=True) + # random to check stats + self.manager._put_kv_received_to_radix_cache(move_tasks.copy()) + for task in move_tasks.copy(): + logger.info( + f"{func_name} put kv to radix cache ok, req_id: {task.id()} cost_time {task.get_cost_time()} s" + ) + self.manager.up_status_in_queue.put( + UpKVStatus(group_request_id=task.group_request_id, dp_index=task.decode_dp_index) + ) + logger.info(f"{func_name} up kv status req_id: {task.id()} finished") + move_tasks.clear() + + except BaseException as e: + logger.exception(str(e)) + self.set_has_error() + self.move_finished_queue.clear_tasks() + + finally: + self.manager.put_to_fail_release_task_queue(move_tasks) + + logger.error(f"{func_name} thread quit, info: {self.to_log_info()}") + return + + # ================================================================================== + # 错误处理检测操作的一些通用函数 + # ================================================================================== + + def timer_to_check_status(self, raise_exception=True): + if self.timer_checker.has_exceeded(): + try: + assert self.kv_trans_process.is_trans_process_health() + except BaseException as e: + logger.error(f"pid {self.kv_trans_process.process.pid} check failed") + logger.exception(str(e)) + + self.set_has_error() + if raise_exception: + raise e + return + + def has_error_status(self): + try: + assert self.has_error is False + assert self.kv_move_thread.is_alive() + assert self.put_to_radix_thread.is_alive() + except BaseException as e: + logger.exception(str(e)) + self.set_has_error() + return True + + return False + + def set_has_error(self): + self.has_error = True + + if self.ready_to_move_queue is not None: + self.ready_to_move_queue.has_error = True + + if self.move_finished_queue is not None: + self.move_finished_queue.has_error = True + + if self.manager is not None: + self.manager.remove_trans_obj(self.connect_id) + return + + def __del__(self): + logger.error(f"trans obj del start, info: {self.to_log_info()}") + + try: + self.set_has_error() + + join_if_alive(self.kv_move_thread) + join_if_alive(self.put_to_radix_thread) + + if self.connect_id is not None and self.kv_trans_process is not None: + self.kv_trans_process.task_in_queue.put( + PDTransLeaveInfo( + decode_id=self.decode_node_id, prefill_id=self.prefill_node_id, connect_id=self.connect_id + ) + ) + + if self.ready_to_move_queue is not None: + self.ready_to_move_queue.clear_tasks() + if self.move_finished_queue is not None: + self.move_finished_queue.clear_tasks() + + except BaseException as e: + logger.exception(str(e)) + + logger.error(f"trans obj deled, info: {self.to_log_info()}") + + def to_log_info(self): + log = f"connect_id: {self.connect_id} " + log += f"decode_node_id: {self.decode_node_id} " + log += f"prefill_node_id: {self.prefill_node_id} " + log += f"device_index: {self.device_index} " + return log + + +@dataclass +class KVTransProcess: + process: mp.Process = None + # 需要每个卡有一个锁来规划每次只能有一个 connection obj 操作对应显卡上的传输任务。 + device_lock: threading.Lock = None + task_in_queue: mp.Queue = None + task_out_queue: mp.Queue = None + device_id: int = None + + def init_all(self, device_id: int, manager: "DecodeKVMoveManager"): + self.device_lock = threading.Lock() + self.device_id = device_id + self.task_in_queue = mp.Queue() + self.task_out_queue = mp.Queue() + + try: + from .decode_trans_process import start_decode_trans_process + + self.process = start_decode_trans_process( + manager.args, + device_id, + self.task_in_queue, + self.task_out_queue, + 
manager.mem_queues, + ) + assert self.task_out_queue.get(timeout=30) == "proc_start" + manager._put_mem_manager_to_mem_queue() + assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" + + return True + + except Exception as e: + logger.warning(f"Failed start kv trans process for device {device_id}: {e}") + logger.exception(str(e)) + return False + + def is_trans_process_health(self): + try: + process = psutil.Process(self.process.pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + logger.error(f"kv trans process for device: {self.device_id} dead!!!") + return False + else: + return True + except: + return False + + def killself(self): + self.process.kill() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index b70bf8efe..782c95326 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -2,26 +2,94 @@ import time import sys import inspect +import threading import torch.multiprocessing as mp -from typing import List, Dict +from torch.distributed import TCPStore +from datetime import timedelta +from typing import List, Dict, Union from lightllm.utils.log_utils import init_logger from lightllm.common.mem_manager import MemoryManager -from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry +from lightllm.distributed.pynccl import PyNcclCommunicator, StatelessP2PProcessGroup logger = init_logger(__name__) -def _init_env( - args, - device_index: int, - nccl_ip, - nccl_port, - task_in_queue: mp.Queue, +def _handle_kvmove_task( + move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, - mem_queues: List[mp.Queue], + mem_managers: List[MemoryManager], + connect_id_to_comm: Dict[str, PyNcclCommunicator], + connect_id: str, + dp_size_in_node: int, +): + total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) + try: + device_index = connect_id_to_comm[connect_id].device.index + start = time.time() + if total_move_kv_len != 0: + cur_mem = mem_managers[device_index] + logger.info(f"trans start: {move_tasks[0].to_decode_log_info()}") + if kv_trans_use_p2p(): + cur_mem.receive_from_prefill_node_p2p( + move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] + ) + else: + cur_mem.receive_from_prefill_node( + move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] + ) + logger.info(f"trans finished: {move_tasks[0].to_decode_log_info()} move len: {total_move_kv_len}") + torch.cuda.synchronize() + logger.info(f"trans cost time: {(time.time() - start)}, {move_tasks[0].to_decode_log_info()}") + task_out_queue.put("ok") + except BaseException as e: + logger.exception(str(e)) + task_out_queue.put("fail") + raise e + + +def _handle_prefill_join( + node_info: PDTransJoinInfo, task_out_queue: mp.Queue, connect_id_to_comm: Dict[str, PyNcclCommunicator] ): + try: + logger.info(f"connect start {node_info}") + store_client = TCPStore( + host_name=node_info.pd_prefill_nccl_ip, + port=node_info.pd_prefill_nccl_port, + is_master=False, + 
use_libuv=True, + timeout=timedelta(seconds=30), + ) + src_id = node_info.prefill_id + dest_id = node_info.connect_id + logger.info(f"connect src_id {src_id} dest_id {dest_id}") + + result_list = [] + + def async_connect(): + torch.cuda.set_device(node_info.decode_device_id) + group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=False, store=store_client) + comm = PyNcclCommunicator(group, node_info.decode_device_id) + result_list.append(comm) + return + + connect_task = threading.Thread(target=async_connect, daemon=True) + connect_task.start() + connect_task.join(timeout=36) + if connect_task.is_alive(): + raise Exception(f"{node_info} connect time out") + + connect_id_to_comm[node_info.connect_id] = result_list[0] + logger.info(f"{node_info} kv trans connected") + task_out_queue.put("nccl_ok") + except Exception as e: + task_out_queue.put("nccl_fail") + logger.warning(f"error while connect to prefill node: {e}") + + +def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue]): import os # os.environ["NCCL_DEBUG"] = "INFO" @@ -31,63 +99,48 @@ def _init_env( torch.backends.cudnn.enabled = False dp_size_in_node = max(1, args.dp // args.nnodes) - node_world_size = args.tp // args.nnodes try: - # 注册graceful 退出的处理 + torch.cuda.set_device(device_id) graceful_registry(inspect.currentframe().f_code.co_name) - torch.cuda.set_device(device_index) - task_out_queue.put("proc_start") + mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] - assert len(mem_managers) == node_world_size - task_out_queue.put("get_mem_managers_ok") - import torch.distributed as dist - from datetime import timedelta - dist.init_process_group( - "nccl", init_method=f"tcp://{nccl_ip}:{nccl_port}", rank=1, world_size=2, timeout=timedelta(seconds=60) - ) - task_out_queue.put("nccl_ok") + task_out_queue.put("get_mem_managers_ok") + connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} while True: - move_tasks: List[KVMoveTask] = task_in_queue.get() - total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) - try: - start = time.time() - if total_move_kv_len != 0: - cur_mem = mem_managers[device_index] - logger.info(f"trans start: {move_tasks[0].to_decode_log_info()}") - if kv_trans_use_p2p(): - cur_mem.receive_from_prefill_node_p2p(move_tasks, mem_managers, dp_size_in_node) - else: - cur_mem.receive_from_prefill_node(move_tasks, mem_managers, dp_size_in_node) - logger.info(f"trans finished: {move_tasks[0].to_decode_log_info()} move len: {total_move_kv_len}") - torch.cuda.synchronize() - logger.info(f"trans cost time: {(time.time() - start)}, {move_tasks[0].to_decode_log_info()}") - task_out_queue.put("ok") - except BaseException as e: - logger.exception(str(e)) - task_out_queue.put("fail") - raise e - except BaseException as e: - logger.exception(str(e)) - sys.exit(-1) - return + task: Union[KVMoveTaskGroup, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get() + if isinstance(task, KVMoveTaskGroup): + _handle_kvmove_task( + task.tasks, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node + ) + elif isinstance(task, PDTransJoinInfo): + _handle_prefill_join(task, task_out_queue, connect_id_to_comm) + elif isinstance(task, PDTransLeaveInfo): + if task.connect_id in connect_id_to_comm: + connect_id_to_comm[task.connect_id].destroy() + logger.info(f"destory {task} nccl communicator.") + else: + logger.info(f"no connect_id {task.connect_id} found in connect_id_to_comm") 
+ + else: + logger.warning(f"unexpected task type: {task}") + + except Exception as e: + logger.error(f"Fatal error happened in kv trans process: {e}") + raise def start_decode_trans_process( args, - device_index: int, - nccl_ip, - nccl_port, + device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): - proc = mp.Process( - target=_init_env, args=(args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, mem_queues) - ) + proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, mem_queues)) proc.start() assert proc.is_alive() - logger.info(f"decode trans kv process start, nccl_ip: {nccl_ip}, nccl_port: {nccl_port}") + logger.info(f"decode trans kv process for device: {device_id} start!") return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index 27b0fbb19..a54b54980 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -11,20 +11,15 @@ import threading import inspect import collections -from dataclasses import dataclass from typing import List, Dict, Union from lightllm.utils.log_utils import init_logger from .prefill_infer_rpyc import PDPrefillInferRpcServer -from lightllm.common.mem_manager import MemoryManager import torch.multiprocessing as mp from lightllm.server.pd_io_struct import KVMoveTask -from lightllm.utils.net_utils import find_available_port from lightllm.utils.retry_utils import retry -from rpyc.utils.classic import obtain from rpyc import AsyncResult from lightllm.utils.net_utils import get_hostname_ip from ..task_queue import TaskQueue -from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_unique_server_name @@ -33,282 +28,6 @@ logger = init_logger(__name__) -@dataclass -class TransProcessObj: - decode_node_id: str = None - rpyc_conn: object = None # rpyc_con 的连接对象 - process: mp.Process = None - task_in_queue: mp.Queue = None - task_out_queue: mp.Queue = None - nccl_ip: str = None - nccl_port: str = None - device_index: str = None # 使用的gpu序号 - manager: "PrefillKVMoveManager" = None - has_error: bool = False - request_kv_trans_task_queue: TaskQueue = None - request_thread: threading.Thread = None - ready_kv_trans_task_queue: TaskQueue = None - kv_trans_thread: threading.Thread = None - latest_check_time: float = None - - def create( - self, decode_node_id: str, decode_node_ip: str, decode_node_rpyc_port: int, manager: "PrefillKVMoveManager" - ): - con = rpyc.connect( - host=decode_node_ip, port=decode_node_rpyc_port, config={"allow_pickle": True}, keepalive=True - ) - nccl_ip = manager.host_ip - nccl_port = find_available_port(manager.args.pd_p_allowed_port_min, manager.args.pd_p_allowed_port_max) - if nccl_port is None: - raise Exception("no pd nccl port can be used") - - from .prefill_trans_process import start_prefill_trans_process - - device_index = manager.get_next_device_index() # 分配 trans 进程使用的显卡 - task_in_queue = mp.Queue() - task_out_queue = mp.Queue() - proc = start_prefill_trans_process( - manager.args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, 
manager.mem_queues - ) - assert task_out_queue.get(timeout=30) == "proc_start" - manager._put_mem_manager_to_mem_queue() - assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" - prefill_node_id = manager.args.pd_node_id - # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 - max_kv_trans_token_num = obtain( - con.root.build_trans_process(prefill_node_id, nccl_ip, nccl_port, manager.args.max_total_token_num) - ) - self.max_kv_trans_token_num = max_kv_trans_token_num - assert task_out_queue.get(timeout=60) == "nccl_ok" - - self.decode_node_id = decode_node_id - self.rpyc_conn = con - self.process = proc - self.task_in_queue = task_in_queue - self.task_out_queue = task_out_queue - self.nccl_port = nccl_port - self.nccl_ip = nccl_ip - self.device_index = device_index - self.manager = manager - self.latest_check_time = time.time() - - self.request_kv_trans_task_queue = TaskQueue( - get_func=self._get_request_tasks, fail_func=self.manager.put_to_release_task_queue - ) - self.request_thread = threading.Thread(target=self.request_kv_trans_loop, daemon=True) - self.request_thread.start() - - self.ready_kv_trans_task_queue = TaskQueue(lambda datas: datas[0:1], self.manager.put_to_release_task_queue) - self.kv_trans_thread = threading.Thread(target=self.kv_trans_handle_loop, daemon=True) - self.kv_trans_thread.start() - return - - def _get_request_tasks(self, datas: List[KVMoveTask]): - ans_list = [] - token_num = 0 - for task in datas: - if token_num + len(task.prefill_token_indexes) <= self.max_kv_trans_token_num: - ans_list.append(task) - token_num += len(task.prefill_token_indexes) - else: - break - return ans_list - - def check_trans_process(self, raise_exception=True): - process = psutil.Process(self.process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - self.set_has_error() - if raise_exception: - raise Exception(f"trans process: {self.process.pid} is dead") - return - - def check_connect(self, raise_exception=True): - try: - self.rpyc_conn.root.check_alive() - except BaseException as e: - self.set_has_error() - if raise_exception: - raise e - return - - def timer_check_status(self, raise_exception=True): - if time.time() - self.latest_check_time >= 2.0: - self.latest_check_time = time.time() - self.check_trans_process(raise_exception=raise_exception) - self.check_connect(raise_exception=raise_exception) - if self.has_error: - self.manager.remove_trans_obj(self.decode_node_id) - return - - def request_kv_trans_loop(self): - func_name = self.request_kv_trans_loop.__name__ - - while not self.has_error: - move_tasks: List[KVMoveTask] = self.request_kv_trans_task_queue.get_tasks( - log_tag="request_kv_trans_task_queue" - ) - if len(move_tasks) == 0: - # 周期检查通信状态 - self.timer_check_status(raise_exception=False) - time.sleep(0.01) - continue - try: - self.timer_check_status(raise_exception=True) - for move_task in move_tasks: - logger.info( - f"{func_name} get task {move_task.to_prefill_log_info()} " - f"queue time {move_task.get_cost_time()} s " - ) - - trans_move_tasks = [copy.copy(move_task) for move_task in move_tasks] - for trans_move_task in trans_move_tasks: - trans_move_task.prefill_token_indexes = None - - mark_start = time.time() - move_kv_lens = self.rpyc_conn.root.request_data_transfer(trans_move_tasks) - move_kv_lens = obtain(move_kv_lens) - request_data_transfer_cost_time = time.time() - mark_start - - logger.info( - f"{func_name} request_data_transfer ok, {move_tasks[0].to_prefill_log_info()}" - f" cost time: {request_data_transfer_cost_time} s" - 
) - - ok_trans_list = [] - for i, move_task in enumerate(move_tasks.copy()): - if move_kv_lens[i] is not None: - move_task.move_kv_len = move_kv_lens[i] - ok_trans_list.append(move_task) - move_tasks.remove(move_task) - else: - logger.info(f"prefill node kv move task req_id: {move_task.id()} not send, decode is busy") - - if len(ok_trans_list) != 0: - self.ready_kv_trans_task_queue.put(ok_trans_list) - - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.manager.remove_trans_obj(self.decode_node_id) - self.request_kv_trans_task_queue.clear_tasks() - - finally: - self.manager.put_to_release_task_queue(move_tasks) - - logger.error(f"{func_name}, decode id {self.decode_node_id} device_index {self.device_index} thread quit") - return - - def _transfer_kv(self, move_tasks: List[KVMoveTask]): - with self.manager.device_locks[self.device_index]: - self.task_in_queue.put(move_tasks.copy(), timeout=10) - assert self.task_out_queue.get(timeout=60) == "ok" - self.manager.put_to_release_task_queue(move_tasks) - - logger.info( - f"_transfer_kv data ok, req_id: {move_tasks[0].id()}" - f" cost total time: {move_tasks[0].get_cost_time()} s" - ) - move_tasks.clear() - - def kv_trans_handle_loop(self): - func_name = self.kv_trans_handle_loop.__name__ - while not self.has_error: - move_tasks: List[List[KVMoveTask]] = self.ready_kv_trans_task_queue.get_tasks( - log_tag="ready_kv_trans_task_queue" - ) - if len(move_tasks) == 0: - self.timer_check_status(raise_exception=False) - time.sleep(0.01) - continue - - if len(move_tasks) != 1: - logger.error(f"error get kv trans move_tasks, must be 1, get {len(move_tasks)}") - assert len(move_tasks) == 1 - - move_tasks = move_tasks[0] - - try: - self.timer_check_status(raise_exception=True) - for move_task in move_tasks: - logger.info( - f"{func_name} get task {move_task.to_prefill_log_info()} to start kv move" - f"queue time {move_task.get_cost_time()} s " - ) - - if not kv_trans_use_p2p(): - with self.manager.kv_trans_lock: - self._transfer_kv(move_tasks) - else: - self._transfer_kv(move_tasks) - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.manager.remove_trans_obj(self.decode_node_id) - self.ready_kv_trans_task_queue.clear_tasks() - finally: - self.manager.put_to_release_task_queue(move_tasks) - - logger.error(f"trans kv thread, decode id {self.decode_node_id} device_index {self.device_index} thread quit") - return - - def wait_thread_quit(self): - if self.request_thread is not None: - if self.request_thread.is_alive(): - try: - self.request_thread.join() - except: - pass - if self.kv_trans_thread is not None: - if self.kv_trans_thread.is_alive(): - try: - self.kv_trans_thread.join() - except: - pass - return - - def has_error_status(self): - try: - assert self.has_error is False - assert self.request_thread.is_alive() - assert self.kv_trans_thread.is_alive() - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - return True - - return False - - def set_has_error(self): - self.has_error = True - try: - self.request_kv_trans_task_queue.has_error = True - self.ready_kv_trans_task_queue.has_error = True - except: - pass - return - - def __del__(self): - logger.error(f"trans obj del start, decode node id {self.decode_node_id} device_index {self.device_index}") - - try: - self.set_has_error() - self.wait_thread_quit() - if self.request_kv_trans_task_queue is not None: - self.request_kv_trans_task_queue.clear_tasks() - if self.ready_kv_trans_task_queue is not None: - 
self.ready_kv_trans_task_queue.clear_tasks() - except BaseException as e: - logger.exception(str(e)) - - logger.error(f"trans obj deled, decode node id {self.decode_node_id} device_index {self.device_index}") - - # 强制关闭连接和杀掉传输进程 - if self.process is not None: - logger.warning(f"prefill trans process {self.process.pid} is killed") - os.kill(self.process.pid, signal.SIGKILL) - pass - - class PrefillKVMoveManager: def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.args = args @@ -322,7 +41,11 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.info_queue = info_queue self.mem_queues = mem_queues self.infer_rpyc_objs: List[PDPrefillInferRpcServer] = [] - self.node_id_to_trans_obj: Dict[str, TransProcessObj] = {} + + from .prefill_trans_obj import KVTransConnectObj + + self.connect_id_to_trans_obj: Dict[str, KVTransConnectObj] = {} + for port in self.args.pd_node_infer_rpyc_ports: socket_path = f"/tmp/{get_unique_server_name()}_prefill_node_infer_rpyc_{port}" from rpyc.utils.factory import unix_connect @@ -337,15 +60,46 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.infer_rpyc_lock = threading.Lock() self.kv_trans_lock = threading.Lock() - # 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。 - self.device_locks = [threading.Lock() for _ in range(self.node_world_size)] - # 释放token的task队列 self.release_task_queue = TaskQueue(lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=None) self.release_tasks_thread = threading.Thread(target=self.handle_release_task_loop, daemon=True) self.release_tasks_thread.start() + + from .prefill_trans_obj import KVTransProcess + + self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size + for device_id in range(self.node_world_size): + self.kv_trans_processes[device_id] = KVTransProcess() + assert self.kv_trans_processes[device_id].init_all(device_id, self) + return + # ================================================================================== + # 主任务循环,接收需要进行kv传输的请求进行处理 + # ================================================================================== + + def task_dispatcher_loop(self): + try: + # 获取任务,并分发给相关卡的处理队列 + while True: + move_task: KVMoveTask = self.info_queue.get() + try: + trans_obj = self.__get_trans_obj(move_task) + trans_obj.request_kv_trans_task_queue.put(move_task) + except BaseException as e: + logger.exception(str(e)) + self.put_to_release_task_queue(move_task) + finally: + trans_obj = None + + except (BaseException, RuntimeError) as e: + logger.exception(str(e)) + raise e + + # ================================================================================== + # 请求出错或者完成kv传输后的处理队列和线程loop + # ================================================================================== + def put_to_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): if isinstance(task, KVMoveTask): self.release_task_queue.put(task) @@ -364,61 +118,34 @@ def handle_release_task_loop(self): self._remove_req_refs_from_prompt_cache(handle_list) return - def get_next_device_index(self): - counts = [0 for _ in range(self.node_world_size)] - for obj in self.node_id_to_trans_obj.values(): - counts[obj.device_index] += 1 - device_index = int(np.argmin(counts)) - return device_index + # ================================================================================== + # 定时检测传输进程的健康状态,出现问题拉崩整个系统触发重启 + # ================================================================================== - def get_trans_obj(self, task: KVMoveTask): - 
self.remove_dead_trans_obj() - if task.decode_node.node_id not in self.node_id_to_trans_obj: - gc.collect() - trans_obj = TransProcessObj() - trans_obj.create(task.decode_node.node_id, task.decode_node.ip, task.decode_node.rpyc_port, self) - self.node_id_to_trans_obj[task.decode_node.node_id] = trans_obj - return self.node_id_to_trans_obj[task.decode_node.node_id] - - def remove_trans_obj(self, decode_node_id): - if decode_node_id in self.node_id_to_trans_obj: - trans_obj = self.node_id_to_trans_obj.pop(decode_node_id, None) - if trans_obj is not None: - trans_obj.set_has_error() - logger.error(f"remove tran obj id {trans_obj.decode_node_id}") - return - - def remove_dead_trans_obj(self): - del_node_ids = [] - for node_id, t_obj in self.node_id_to_trans_obj.items(): - if t_obj.has_error_status(): - del_node_ids.append(node_id) - - for node_id in del_node_ids: - self.node_id_to_trans_obj.pop(node_id, None) - - if len(del_node_ids) != 0: - gc.collect() - return - - def task_dispatcher_loop(self): + def check_trans_process_loop(self): try: - # 获取任务,并分发给相关卡的处理队列 while True: - move_task: KVMoveTask = self.info_queue.get() - try: - trans_obj = self.get_trans_obj(move_task) - trans_obj.request_kv_trans_task_queue.put(move_task) - except BaseException as e: - logger.exception(str(e)) - self.put_to_release_task_queue(move_task) - finally: - trans_obj = None + for device_id in range(self.node_world_size): + if not self.kv_trans_processes[device_id].is_trans_process_health(): + raise Exception(f"device_id {device_id} kv process is unhealth") + time.sleep(10.0) except (BaseException, RuntimeError) as e: logger.exception(str(e)) + + for device_id in range(self.node_world_size): + self.kv_trans_processes[device_id].killself() + + # 杀掉当前进程的父进程(router), 触发全局崩溃 + os.kill(os.getppid(), signal.SIGKILL) + os.kill(os.getpid(), signal.SIGKILL) raise e + # ================================================================================== + # 与推理进程交互接口, _remove_req_refs_from_prompt_cache 和 + # _put_mem_manager_to_mem_queue 都是通过 rpyc 与推理进程进行交互的接口 + # ================================================================================== + def _remove_req_refs_from_prompt_cache(self, tasks: List[KVMoveTask]): with self.infer_rpyc_lock: dp_to_tasks = collections.defaultdict(list) @@ -446,6 +173,54 @@ async def wait_all_future_finish(self, futures: List[AsyncResult]): await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures]) return + # ================================================================================== + # 辅助功能接口 + # ================================================================================== + + def get_next_device_index(self): + counts = [0 for _ in range(self.node_world_size)] + for obj in self.connect_id_to_trans_obj.values(): + counts[obj.device_index] += 1 + device_index = int(np.argmin(counts)) + return device_index + + def remove_trans_obj(self, connect_id): + if connect_id in self.connect_id_to_trans_obj: + trans_obj = self.connect_id_to_trans_obj.pop(connect_id, None) + if trans_obj is not None: + trans_obj.set_has_error() + logger.error(f"remove tran obj decode_node_id {trans_obj.decode_node_id}") + return + + def __get_trans_obj(self, task: KVMoveTask): + self.__remove_dead_trans_obj() + # 如果已经存在连接对象,直接返回 + for obj in self.connect_id_to_trans_obj.values(): + if obj.decode_node_id == task.decode_node.node_id: + return obj + + # 如果不存在连接对象,创建新的连接对象 + gc.collect() + from .prefill_trans_obj import KVTransConnectObj + + trans_obj = KVTransConnectObj() + 
trans_obj.create(task.decode_node.node_id, task.decode_node.ip, task.decode_node.rpyc_port, self) + self.connect_id_to_trans_obj[trans_obj.connect_id] = trans_obj + return trans_obj + + def __remove_dead_trans_obj(self): + del_connect_ids = [] + for connect_id, t_obj in self.connect_id_to_trans_obj.items(): + if t_obj.has_error_status(): + del_connect_ids.append(connect_id) + + for connect_id in del_connect_ids: + self.connect_id_to_trans_obj.pop(connect_id, None) + + if del_connect_ids: + gc.collect() + return + def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event): import lightllm.utils.rpyc_fix_utils as _ @@ -454,6 +229,8 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp. graceful_registry(inspect.currentframe().f_code.co_name) manager = PrefillKVMoveManager(args, info_queue, mem_queues) + kv_trans_process_check = threading.Thread(target=manager.check_trans_process_loop, daemon=True) + kv_trans_process_check.start() event.set() # 进入主循环 manager.task_dispatcher_loop() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py new file mode 100644 index 000000000..f53761e09 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py @@ -0,0 +1,380 @@ +import time +import rpyc +import copy +import uuid +import numpy as np +import psutil +import threading +from dataclasses import dataclass +from typing import List, Dict, Union +from lightllm.utils.log_utils import init_logger +import torch.multiprocessing as mp +from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup +from rpyc.utils.classic import obtain +from ..task_queue import TaskQueue +from lightllm.utils.device_utils import kv_trans_use_p2p +from lightllm.utils.time_utils import TimeChecker +from .prefill_kv_move_manager import PrefillKVMoveManager +from lightllm.utils.net_utils import find_available_port +from ..utils import join_if_alive, clear_queue + +logger = init_logger(__name__) + + +@dataclass +class KVTransConnectObj: + connect_id: str = None + decode_node_id: int = None + rpyc_conn: object = None # rpyc_con 的连接对象 + kv_trans_process: "KVTransProcess" = None + device_index: int = None # 使用的gpu序号 + manager: "PrefillKVMoveManager" = None + has_error: bool = False + request_kv_trans_task_queue: TaskQueue = None + request_thread: threading.Thread = None + ready_kv_trans_task_queue: TaskQueue = None + kv_trans_thread: threading.Thread = None + timer_checker: TimeChecker = None + + # ================================================================================== + # 构建传输通信对象 + # ================================================================================== + + def create( + self, decode_node_id: int, decode_node_ip: str, decode_node_rpyc_port: int, manager: "PrefillKVMoveManager" + ): + device_index = manager.get_next_device_index() # 分配使用的显卡index + self.kv_trans_process = manager.kv_trans_processes[device_index] + prefill_node_id = manager.args.pd_node_id + self.connect_id = str(uuid.uuid4()) + self.decode_node_id = decode_node_id + self.prefill_node_id = prefill_node_id + self.device_index = device_index + self.manager = manager + self.timer_checker = TimeChecker(6) + + con = rpyc.connect( + host=decode_node_ip, + port=decode_node_rpyc_port, + config={"allow_pickle": 
True, "sync_request_timeout": 60}, + keepalive=True, + ) + + self.rpyc_conn = con + + # 创建 nccl 连接 + with self.kv_trans_process.device_lock: + clear_queue(self.kv_trans_process.task_out_queue) + + self.kv_trans_process.task_in_queue.put( + PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=device_index, + pd_prefill_nccl_ip=manager.host_ip, + pd_prefill_nccl_port=self.kv_trans_process.kv_trans_port, + decode_id=decode_node_id, + decode_device_id=-1, + connect_id=self.connect_id, + ) + ) + + # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 + max_kv_trans_token_num = obtain( + con.root.build_trans_connect( + prefill_node_id, + manager.host_ip, + self.kv_trans_process.kv_trans_port, + manager.args.max_total_token_num, + self.connect_id, + ) + ) + self.max_kv_trans_token_num = max_kv_trans_token_num + assert self.kv_trans_process.task_out_queue.get(timeout=60) == "nccl_ok" + + self.request_kv_trans_task_queue = TaskQueue( + get_func=self._get_request_tasks, fail_func=self.manager.put_to_release_task_queue + ) + self.request_thread = threading.Thread(target=self.request_kv_trans_loop, daemon=True) + self.request_thread.start() + + self.ready_kv_trans_task_queue = TaskQueue(lambda datas: datas[0:1], self.manager.put_to_release_task_queue) + self.kv_trans_thread = threading.Thread(target=self.kv_trans_handle_loop, daemon=True) + self.kv_trans_thread.start() + + logger.info(f"create KVTransConnectObj success: {self.to_log_info()}") + return + + def _get_request_tasks(self, datas: List[KVMoveTask]): + """ + 根据可以p和d节点间协商得到的 max_kv_trans_token_num 限制,将排队等待 + 传输的请求打包成一个可以传输的list组。 + """ + ans_list = [] + token_num = 0 + for task in datas: + if token_num + len(task.prefill_token_indexes) <= self.max_kv_trans_token_num: + ans_list.append(task) + token_num += len(task.prefill_token_indexes) + else: + break + return ans_list + + # ================================================================================== + # 与 decode 节点进行元数据交互,申请锁定资源准备进行kv的传输 + # ================================================================================== + def request_kv_trans_loop(self): + func_name = self.request_kv_trans_loop.__name__ + + while not self.has_error: + move_tasks: List[KVMoveTask] = self.request_kv_trans_task_queue.get_tasks( + log_tag="request_kv_trans_task_queue" + ) + if len(move_tasks) == 0: + self.timer_check_status(raise_exception=False) + time.sleep(0.01) + continue + try: + self.timer_check_status(raise_exception=True) + for move_task in move_tasks: + move_task.connect_id = self.connect_id + logger.info( + f"{func_name} get task {move_task.to_prefill_log_info()} " + f"queue time {move_task.get_cost_time()} s " + ) + + trans_move_tasks = [copy.copy(move_task) for move_task in move_tasks] + for trans_move_task in trans_move_tasks: + trans_move_task.prefill_token_indexes = None + + mark_start = time.time() + move_kv_lens = self.rpyc_conn.root.request_data_transfer(trans_move_tasks) + move_kv_lens = obtain(move_kv_lens) + request_data_transfer_cost_time = time.time() - mark_start + + logger.info( + f"{func_name} request_data_transfer ok, {move_tasks[0].to_prefill_log_info()}" + f" cost time: {request_data_transfer_cost_time} s" + ) + + ok_trans_list = [] + for i, move_task in enumerate(move_tasks.copy()): + if move_kv_lens[i] is not None: + move_task.move_kv_len = move_kv_lens[i] + ok_trans_list.append(move_task) + move_tasks.remove(move_task) + else: + logger.info(f"prefill node kv move task req_id: {move_task.id()} not send, decode is busy") + + if ok_trans_list: + 
self.ready_kv_trans_task_queue.put( + ok_trans_list, error_handle_func=self.manager.put_to_release_task_queue + ) + + except BaseException as e: + logger.exception(str(e)) + self.set_has_error() + self.request_kv_trans_task_queue.clear_tasks() + + finally: + # 将没有申请成功的请求放入到释放队列中 + self.manager.put_to_release_task_queue(move_tasks) + + logger.error(f"{func_name}, {self.to_log_info()} thread quit") + return + + # ================================================================================== + # 将准备好 kv 传输的请求进行 kv 传输 + # ================================================================================== + def _transfer_kv(self, move_tasks: List[KVMoveTask]): + with self.kv_trans_process.device_lock: + clear_queue(self.kv_trans_process.task_out_queue) + kv_move_group = KVMoveTaskGroup(tasks=move_tasks.copy(), connect_id=self.connect_id) + self.kv_trans_process.task_in_queue.put(kv_move_group, timeout=10) + assert self.kv_trans_process.task_out_queue.get(timeout=60) == "ok" + self.manager.put_to_release_task_queue(move_tasks) + + logger.info( + f"_transfer_kv data ok, req_id: {move_tasks[0].id()}" + f" cost total time: {move_tasks[0].get_cost_time()} s" + ) + move_tasks.clear() + + def kv_trans_handle_loop(self): + func_name = self.kv_trans_handle_loop.__name__ + while not self.has_error: + move_tasks: List[List[KVMoveTask]] = self.ready_kv_trans_task_queue.get_tasks( + log_tag="ready_kv_trans_task_queue" + ) + if len(move_tasks) == 0: + self.timer_check_status(raise_exception=False) + time.sleep(0.01) + continue + + if len(move_tasks) != 1: + logger.error(f"error get kv trans move_tasks, must be 1, get {len(move_tasks)}") + assert len(move_tasks) == 1 + + move_tasks: List[KVMoveTask] = move_tasks[0] + + try: + self.timer_check_status(raise_exception=True) + for move_task in move_tasks: + logger.info( + f"{func_name} get task {move_task.to_prefill_log_info()} to start kv move" + f"queue time {move_task.get_cost_time()} s " + ) + + if not kv_trans_use_p2p(): + with self.manager.kv_trans_lock: + self._transfer_kv(move_tasks) + else: + self._transfer_kv(move_tasks) + except BaseException as e: + logger.exception(str(e)) + self.set_has_error() + self.ready_kv_trans_task_queue.clear_tasks() + finally: + self.manager.put_to_release_task_queue(move_tasks) + + logger.error(f"trans kv thread, {self.to_log_info()} thread quit") + return + + # ================================================================================== + # 错误处理检测操作的一些通用函数 + # ================================================================================== + + def has_error_status(self): + try: + assert self.has_error is False + assert self.request_thread.is_alive() + assert self.kv_trans_thread.is_alive() + except BaseException as e: + logger.exception(str(e)) + self.set_has_error() + return True + + return False + + def timer_check_status(self, raise_exception=True): + if self.timer_checker.has_exceeded(): + try: + self.rpyc_conn.root.check_alive() + assert self.kv_trans_process.is_trans_process_health() + except BaseException as e: + logger.error(f"pid {self.kv_trans_process.process.pid} check failed") + logger.exception(str(e)) + + self.set_has_error() + if raise_exception: + raise e + + return + + def set_has_error(self): + """ + 将当前传输对象标记为有错误,这样可以防止请求放入到处理队列中 + """ + self.has_error = True + + if self.request_kv_trans_task_queue is not None: + self.request_kv_trans_task_queue.has_error = True + + if self.ready_kv_trans_task_queue is not None: + self.ready_kv_trans_task_queue.has_error = True + + if self.manager is not None: 
+ self.manager.remove_trans_obj(self.connect_id) + return + + def __del__(self): + """ + 函数中有很多判断是否是None的操作,主要是为了避免一些异常流程的del行为不报错。 + """ + logger.error(f"trans obj del start, info: {self.to_log_info()}") + + try: + self.set_has_error() + + join_if_alive(self.request_thread) + join_if_alive(self.kv_trans_thread) + + # 将未处理的请求,清理掉,clear_tasks 会将没处理完的请求 + # 放入到 manager 资源释放队列中 + if self.request_kv_trans_task_queue is not None: + self.request_kv_trans_task_queue.clear_tasks() + if self.ready_kv_trans_task_queue is not None: + self.ready_kv_trans_task_queue.clear_tasks() + + # 传输进程清理掉 nccl 连接 + if self.connect_id is not None: + self.kv_trans_process.task_in_queue.put( + PDTransLeaveInfo( + decode_id=self.decode_node_id, prefill_id=self.prefill_node_id, connect_id=self.connect_id + ) + ) + + except BaseException as e: + logger.exception(str(e)) + + logger.error(f"trans obj deled, info: {self.to_log_info()}") + + def to_log_info(self): + log = f"connect_id: {self.connect_id} " + log += f"decode_node_id: {self.decode_node_id} " + log += f"prefill_node_id: {self.prefill_node_id} " + log += f"device_index: {self.device_index} " + return log + + +@dataclass +class KVTransProcess: + process: mp.Process = None + # 需要每个卡有一个锁来规划每次只能有一个 connection obj 操作对应显卡上的传输任务。 + device_lock: threading.Lock = None + task_in_queue: mp.Queue = None + task_out_queue: mp.Queue = None + device_id: int = None + kv_trans_port: int = None + + def init_all(self, device_id: int, manager: "PrefillKVMoveManager"): + self.device_id = device_id + self.device_lock = threading.Lock() + self.task_in_queue = mp.Queue() + self.task_out_queue = mp.Queue() + self.kv_trans_port = find_available_port(manager.args.pd_p_allowed_port_min, manager.args.pd_p_allowed_port_max) + + try: + from .prefill_trans_process import start_prefill_trans_process + + self.process = start_prefill_trans_process( + manager.args, + manager.host_ip, + self.kv_trans_port, + device_id, + self.task_in_queue, + self.task_out_queue, + manager.mem_queues, + ) + assert self.task_out_queue.get(timeout=30) == "proc_start" + manager._put_mem_manager_to_mem_queue() + assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" + + return True + except Exception as e: + logger.warning(f"Failed start kv trans process for device {device_id}: {e}") + logger.exception(str(e)) + return False + + def is_trans_process_health(self): + try: + process = psutil.Process(self.process.pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + logger.error(f"kv trans process for device: {self.device_id} dead!!!") + return False + else: + return True + except: + return False + + def killself(self): + self.process.kill() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index 1973aabac..3e42a532d 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -2,25 +2,94 @@ import time import sys import inspect +import threading import torch.multiprocessing as mp -from typing import List, Dict +from torch.distributed import TCPStore +from datetime import timedelta +from typing import List, Dict, Union from lightllm.utils.log_utils import init_logger from lightllm.common.mem_manager 
import MemoryManager -from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry +from lightllm.distributed.pynccl import StatelessP2PProcessGroup, PyNcclCommunicator logger = init_logger(__name__) -# device_index 是用来指示,当前传输进程使用的用于数据传输的显卡id -# 当模型是多卡推理的时候,需要传输的 kv 需要先移动到 device_index -# 指定的显卡上,然后再进行传输,因为torch nccl 限制了只能操作一张显卡上的数据 + +def _handle_kvmove_task( + move_tasks: List[KVMoveTask], + task_out_queue: mp.Queue, + mem_managers: List[MemoryManager], + connect_id_to_comm: Dict[str, PyNcclCommunicator], + connect_id: str, + dp_size_in_node: int, +): + total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) + try: + device_index = connect_id_to_comm[connect_id].device.index + start = time.time() + if total_move_kv_len != 0: + logger.info(f"trans start: {move_tasks[0].to_prefill_log_info()}") + cur_mem = mem_managers[device_index] + if kv_trans_use_p2p(): + cur_mem.send_to_decode_node_p2p( + move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] + ) + else: + cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id]) + logger.info(f"trans finished: {move_tasks[0].to_prefill_log_info()} move len: {total_move_kv_len}") + torch.cuda.synchronize() + logger.info( + f"trans cost time: {(time.time() - start)}," + f"move_total_kv_len: {total_move_kv_len}, {move_tasks[0].to_prefill_log_info()}" + ) + task_out_queue.put("ok") + except BaseException as e: + logger.exception(str(e)) + task_out_queue.put("fail") + + +def _handle_decode_join( + node_info: PDTransJoinInfo, + task_out_queue: mp.Queue, + connect_id_to_comm: Dict[str, PyNcclCommunicator], + store: TCPStore, +): + try: + logger.info(f"connect start {node_info}") + src_id = node_info.prefill_id + dest_id = node_info.connect_id + logger.info(f"connect src_id {src_id} dest_id {dest_id}") + result_list = [] + + def async_connect(): + torch.cuda.set_device(node_info.prefill_device_id) + group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=True, store=store) + comm = PyNcclCommunicator(group, node_info.prefill_device_id) + result_list.append(comm) + return + + connect_task = threading.Thread(target=async_connect, daemon=True) + connect_task.start() + connect_task.join(timeout=36) + if connect_task.is_alive(): + raise Exception(f"{node_info} connect time out") + + connect_id_to_comm[node_info.connect_id] = result_list[0] + logger.info(f"{node_info} kv trans connected!") + task_out_queue.put("nccl_ok") + except Exception as e: + task_out_queue.put("nccl_fail") + logger.warning(f"error while connect to decode node: {e} node_info {node_info}") + + def _init_env( args, - device_index: int, - nccl_ip, - nccl_port, + store_ip, + store_port, + device_id, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue], @@ -33,67 +102,54 @@ def _init_env( os.environ["NCCL_SOCKET_NTHREADS"] = "1" torch.backends.cudnn.enabled = False - dp_size_in_node = max(1, args.dp // args.nnodes) - node_world_size = args.tp // args.nnodes - try: - # 注册graceful 退出的处理 + torch.cuda.set_device(device_id) graceful_registry(inspect.currentframe().f_code.co_name) - torch.cuda.set_device(device_index) - + master_store = TCPStore( + host_name=store_ip, port=store_port, is_master=True, use_libuv=True, timeout=timedelta(seconds=30) + ) + dp_size_in_node = 
max(1, args.dp // args.nnodes) task_out_queue.put("proc_start") mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] - assert len(mem_managers) == node_world_size task_out_queue.put("get_mem_managers_ok") - import torch.distributed as dist - from datetime import timedelta + connect_id_to_comm: Dict[str, PyNcclCommunicator] = {} - dist.init_process_group( - "nccl", init_method=f"tcp://{nccl_ip}:{nccl_port}", rank=0, world_size=2, timeout=timedelta(seconds=60) - ) - task_out_queue.put("nccl_ok") while True: - move_tasks: List[KVMoveTask] = task_in_queue.get() - total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) - try: - start = time.time() - if total_move_kv_len != 0: - logger.info(f"trans start: {move_tasks[0].to_prefill_log_info()}") - cur_mem = mem_managers[device_index] - if kv_trans_use_p2p(): - cur_mem.send_to_decode_node_p2p(move_tasks, mem_managers, dp_size_in_node) - else: - cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node) - logger.info(f"trans finished: {move_tasks[0].to_prefill_log_info()} move len: {total_move_kv_len}") - torch.cuda.synchronize() - logger.info( - f"trans cost time: {(time.time() - start)}," - f"move_total_kv_len: {total_move_kv_len}, {move_tasks[0].to_prefill_log_info()}" + task: Union[KVMoveTaskGroup, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get() + if isinstance(task, KVMoveTaskGroup): + _handle_kvmove_task( + task.tasks, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node ) - task_out_queue.put("ok") - except BaseException as e: - logger.exception(str(e)) - task_out_queue.put("fail") - raise e - except BaseException as e: - logger.exception(str(e)) - sys.exit(-1) - return + elif isinstance(task, PDTransJoinInfo): + _handle_decode_join(task, task_out_queue, connect_id_to_comm, master_store) + elif isinstance(task, PDTransLeaveInfo): + if task.connect_id in connect_id_to_comm: + connect_id_to_comm[task.connect_id].destroy() + connect_id_to_comm.pop(task.connect_id, None) + logger.info(f"destory {task} nccl communicator.") + else: + logger.error(f"connect id {task.connect_id} dont exist in connect_id_to_comm") + else: + logger.warning(f"unexpected task type: {task}") + + except Exception as e: + logger.error(f"Fatal error happened in kv trans process: {e}") + pass def start_prefill_trans_process( args, - device_index: int, - nccl_ip, - nccl_port, + store_ip, + store_port, + device_id, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): proc = mp.Process( - target=_init_env, args=(args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, mem_queues) + target=_init_env, args=(args, store_ip, store_port, device_id, task_in_queue, task_out_queue, mem_queues) ) proc.start() assert proc.is_alive() - logger.info(f"trans kv process start, nccl_ip: {nccl_ip}, nccl_port: {nccl_port}") + logger.info(f"prefill trans kv process for device: {device_id} started!") return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py index 9dd4b3c5f..7b856e54a 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py @@ -15,8 +15,10 @@ def __init__(self, get_func, fail_func): def size(self): return len(self.datas) - def put(self, obj): + def put(self, obj, 
error_handle_func=None):
         if self.has_error:
+            if error_handle_func is not None:
+                error_handle_func(obj)
             raise Exception("has error")

         with self.lock:
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py
new file mode 100644
index 000000000..cd1360fd0
--- /dev/null
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py
@@ -0,0 +1,20 @@
+import threading
+import torch.multiprocessing as mp
+from queue import Empty
+
+
+def join_if_alive(thread: threading.Thread):
+    if thread is not None and thread.is_alive():
+        try:
+            thread.join()
+        except Exception:
+            pass
+    return
+
+
+def clear_queue(queue: mp.Queue):
+    while not queue.empty():
+        try:
+            queue.get_nowait()
+        except Empty:
+            break
diff --git a/lightllm/utils/process_check.py b/lightllm/utils/process_check.py
index 75a8f890d..00cc258bf 100644
--- a/lightllm/utils/process_check.py
+++ b/lightllm/utils/process_check.py
@@ -42,5 +42,5 @@ def start_parent_check_thread():
     """
     检测父进程是否健康,如果出现问题,清理退出所有进程
     """
-    thread = threading.Thread(target=check_parent_alive)
+    thread = threading.Thread(target=check_parent_alive, daemon=True)
     thread.start()
diff --git a/lightllm/utils/time_utils.py b/lightllm/utils/time_utils.py
new file mode 100644
index 000000000..648108d2b
--- /dev/null
+++ b/lightllm/utils/time_utils.py
@@ -0,0 +1,17 @@
+import time
+
+
+class TimeChecker:
+    def __init__(self, threshold):
+        self.threshold = threshold
+        self.last_checked = time.time()
+
+    def has_exceeded(self):
+        current_time = time.time()
+        if (current_time - self.last_checked) > self.threshold:
+            self._reset()
+            return True
+        return False
+
+    def _reset(self):
+        self.last_checked = time.time()
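
Below is a minimal usage sketch (not part of the patch) of how the new TimeChecker is meant to gate the periodic health checks performed by the timer_check_status / timer_to_check_status helpers in the new trans objects. It assumes this patch is applied so that lightllm.utils.time_utils exists; check_health and handle_error are hypothetical stand-ins for the real rpyc check_alive / is_trans_process_health calls and the set_has_error handling.

import time
from lightllm.utils.time_utils import TimeChecker


def polling_loop(check_health, handle_error):
    # The trans objects in this patch construct TimeChecker(6), i.e. a 6 second threshold.
    checker = TimeChecker(6)
    while True:
        # has_exceeded() returns True at most once per threshold window and resets
        # itself, so the comparatively expensive health check is rate limited while
        # the loop itself can poll its task queues at a short interval.
        if checker.has_exceeded():
            try:
                assert check_health()  # stand-in for rpyc check_alive / process health check
            except BaseException as e:
                handle_error(e)  # stand-in for set_has_error + queue cleanup
                break
        time.sleep(0.01)

The same error path motivates the TaskQueue.put(obj, error_handle_func=...) change above: when a queue has already been flagged with has_error, the objects being enqueued are first handed to the supplied handler (put_to_fail_release_task_queue / put_to_release_task_queue in the managers) before the exception is raised, so the tasks are routed into the release path rather than silently dropped.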