From fda8c37980c3c48efb92ddf692b08dbf7cce4edf Mon Sep 17 00:00:00 2001
From: chenjian
Date: Thu, 14 Dec 2023 15:39:00 +0800
Subject: [PATCH] Improve robustness for llm (#2321)

* add inference load balancer for fastdeploy llm
* add inference load balance controller for llm
* add ic for llm
* add ic for llm
* add fastdeploy ic for llm
* add fastdeploy ic to llm
* Fix asyncio.CancelledError exception
* Improve robustness for llm service
* Improve robustness for llm service
* Add detailed log for llm service
* Add detailed log for llm service
* Add detailed log for llm service
* Add detailed log for llm service
* Add detailed log for llm service
---
 llm/fastdeploy_llm/config.py               | 10 ++-
 llm/fastdeploy_llm/engine.py               |  5 --
 llm/fastdeploy_llm/serving/triton_model.py | 85 ++++++++++++++++------
 llm/fastdeploy_llm/utils/launch_infer.py   | 11 ++-
 llm/fastdeploy_llm/utils/logging_util.py   | 19 +++++
 5 files changed, 96 insertions(+), 34 deletions(-)

diff --git a/llm/fastdeploy_llm/config.py b/llm/fastdeploy_llm/config.py
index d78c98952f..47b6f05a61 100644
--- a/llm/fastdeploy_llm/config.py
+++ b/llm/fastdeploy_llm/config.py
@@ -25,12 +25,18 @@ def __init__(self, model_dir, decode_strategy="sampling", mp_num=None):
         self.model_dir = model_dir
         is_static, rank = check_model(model_dir)
+        self.log_home = os.getenv("LOG_HOME", ".")
+        fastdeploy_llm.utils.logging_util.warning_logger = Logger(
+            name="fastDeploy_llm_serving_warning",
+            log_file=os.path.join(self.log_home, "fastdeploy_llm_serving_warning.log"),
+            time_rotation=7,
+            level=logging.DEBUG)
         if os.getenv("ENABLE_DEBUG_LOG", "0") == "1":
             logger.info(
                 "Detect enviroment variable `ENABLE_DEBUG_LOG`, all the debug log information will output to fastdeploy_llm_serving.log."
             )
             fastdeploy_llm.utils.logging_util.logger = Logger(
-                log_file="fastdeploy_llm_serving.log",
+                log_file=os.path.join(self.log_home, "fastdeploy_llm_serving.log"),
                 time_rotation=7,
                 level=logging.DEBUG)
         else:
             logger.info(
@@ -38,7 +44,7 @@
                 "The logging level is set as INFO, if more information needed, please execute `export ENABLE_DEBUG_LOG=1` before launching service."
             )
             fastdeploy_llm.utils.logging_util.logger = Logger(
-                log_file="fastdeploy_llm_serving.log",
+                log_file=os.path.join(self.log_home, "fastdeploy_llm_serving.log"),
                 time_rotation=7,
                 level=logging.INFO)
 
diff --git a/llm/fastdeploy_llm/engine.py b/llm/fastdeploy_llm/engine.py
index e749c15449..5de7e350f4 100644
--- a/llm/fastdeploy_llm/engine.py
+++ b/llm/fastdeploy_llm/engine.py
@@ -548,11 +548,6 @@ def run(infer_engine):
     flag_ready_array[rank] = 1  # init done
 
     while 1:
-        if serving_pid > 0 and (not is_process_running(serving_pid)):
-            print(
-                "[IMPORTANT] The serving process {} is not running, will terminate engine now.".
-                format(serving_pid))
-            break
         if flag_begin_array[rank] != 1:
             continue
 
diff --git a/llm/fastdeploy_llm/serving/triton_model.py b/llm/fastdeploy_llm/serving/triton_model.py
index d0b40e05df..6ec50ad750 100644
--- a/llm/fastdeploy_llm/serving/triton_model.py
+++ b/llm/fastdeploy_llm/serving/triton_model.py
@@ -19,8 +19,10 @@
 import time
 import numpy as np
 import functools
+from collections import defaultdict
 from fastdeploy_llm.serving.serving_model import ServingModel
-from fastdeploy_llm.utils.logging_util import logger
+from fastdeploy_llm.utils.logging_util import logger, warning_logger
+from fastdeploy_llm.utils.logging_util import error_format, ErrorCode, ErrorType
 from fastdeploy_llm.task import Task, BatchTask
 import fastdeploy_llm as fdlm
 
@@ -38,11 +40,14 @@ def stream_call_back(call_back_task, token_tuple, index, is_last_token,
     out["req_id"] = call_back_task.task_id
     out["token_ids"] = [token_tuple[0]]
     out['send_idx'] = index
-    out["is_end"] = 1 if is_last_token else 0
+    out["is_end"] = is_last_token
     out_tensor = pb_utils.Tensor(
         "OUT", np.array(
             [json.dumps(out)], dtype=np.object_))
     if is_last_token:
+        all_token_ids = [t[0] for t in call_back_task.result.completion_tokens] + [token_tuple[0]]
+        all_strs = "".join([t[1] for t in call_back_task.result.completion_tokens]) + token_tuple[1]
+        logger.info("Model output for req_id: {} results_all: {} tokens_all: {}".format(call_back_task.task_id, all_strs, all_token_ids))
         sender[call_back_task.task_id].send(
             pb_utils.InferenceResponse([out_tensor]),
             flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
@@ -68,10 +73,14 @@ def initialize(self, args):
         using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
             self.model_config)
         if not using_decoupled:
-            raise pb_utils.TritonModelException(
-                """the model `{}` can generate any number of responses per request,
+            error_type = ErrorType.Server
+            error_code = ErrorCode.S0001
+            error_info = """the model `{}` can generate any number of responses per request,
                 enable decoupled transaction policy in model configuration to
-                serve this model""".format(args["model_name"]))
+                serve this model""".format(args["model_name"])
+            error_msg = error_format.format(error_type.name, error_code.name, error_info)
+            warning_logger.error(error_msg)
+            raise pb_utils.TritonModelException(error_msg)
 
         parameters = self.model_config["parameters"]
 
@@ -112,10 +121,13 @@ def execute(self, requests):
                 if isinstance(data, list):
                     data = data[0]
             except Exception as e:
+                error_type = ErrorType.Query
+                error_code = ErrorCode.C0000
+                error_info = "Cannot load json data from request, received data = {} error={}.".format(request_tensor, e)
+                error_msg = error_format.format(error_type.name, error_code.name, error_info)
+                warning_logger.error(error_msg)
                 error_res = pb_utils.InferenceResponse(
-                    error=pb_utils.TritonError(
-                        "Cannot load json data from request, error={}.".format(
-                            e)))
+                    error=pb_utils.TritonError(error_msg))
                 res_sender = request.get_response_sender()
                 res_sender.send(
                     error_res,
@@ -127,9 +139,13 @@ def execute(self, requests):
             try:
                 task.from_dict(data)
             except Exception as e:
-                error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(
-                    "There's error while deserializing data from request, error={}".
-                    format(e)))
+                error_type = ErrorType.Query
+                error_code = ErrorCode.C0001
+                error_info = "There's error while deserializing data from request, received data = {} error={}".format(data, e)
+                error_msg = error_format.format(error_type.name, error_code.name, error_info)
+                warning_logger.error(error_msg)
+                error_res = pb_utils.InferenceResponse(
+                    error=pb_utils.TritonError(error_msg))
                 res_sender = request.get_response_sender()
                 res_sender.send(
                     error_res,
@@ -140,9 +156,13 @@ def execute(self, requests):
             if task.task_id is None:
                 task.task_id = str(uuid.uuid4())
             if task.task_id in self.response_handler:
+                error_type = ErrorType.Query
+                error_code = ErrorCode.C0001
+                error_info = "Task id conflict with {}.".format(task.task_id)
+                error_msg = error_format.format(error_type.name, error_code.name, error_info)
+                warning_logger.error(error_msg)
                 error_res = pb_utils.InferenceResponse(
-                    error=pb_utils.TritonError(
-                        "Task id conflict with {}.".format(task.task_id)))
+                    error=pb_utils.TritonError(error_msg))
                 res_sender = request.get_response_sender()
                 res_sender.send(
                     error_res,
@@ -153,10 +173,13 @@ def execute(self, requests):
             try:
                 task.check(self.config.max_dec_len)
             except Exception as e:
+                error_type = ErrorType.Query
+                error_code = ErrorCode.C0001
+                error_info = "There's error while checking task, task={} error={}".format(task, e)
+                error_msg = error_format.format(error_type.name, error_code.name, error_info)
+                warning_logger.error(error_msg)
                 error_res = pb_utils.InferenceResponse(
-                    error=pb_utils.TritonError(
-                        "There's error while checking task, error={}".format(
-                            e)))
+                    error=pb_utils.TritonError(error_msg))
                 res_sender = request.get_response_sender()
                 res_sender.send(
                     error_res,
@@ -165,9 +188,12 @@ def execute(self, requests):
 
             # 5. check if the requests queue is full
             if self.model.requests_queue.qsize() > self.config.max_queue_num:
-                error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(
-                    "The queue is full now(size={}), please wait for a while.".
-                    format(self.model.max_queue_num)))
+                error_type = ErrorType.Server
+                error_code = ErrorCode.S0000
+                error_info = "The queue is full now(size={}), please wait for a while.".format(self.model.max_queue_num)
+                error_msg = error_format.format(error_type.name, error_code.name, error_info)
+                warning_logger.error(error_msg)
+                error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(error_msg))
                 res_sender = request.get_response_sender()
                 res_sender.send(
                     error_res,
@@ -195,10 +221,12 @@ def execute(self, requests):
             try:
                 self.model.add_request(task)
             except Exception as e:
-                error_res = pb_utils.InferenceResponse(
-                    error=pb_utils.TritonError(
-                        "There's error while inserting new request, error={}".
-                        format(e)))
+                error_type = ErrorType.Query
+                error_code = ErrorCode.C0001
+                error_info = "There's error while inserting new request, task={} error={}".format(task, e)
+                error_msg = error_format.format(error_type.name, error_code.name, error_info)
+                warning_logger.error(error_msg)
+                error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(error_msg))
                 res_sender = request.get_response_sender()
                 res_sender.send(
                     error_res,
@@ -208,5 +236,16 @@ def execute(self, requests):
 
     def finalize(self):
         logger.info("The triton server is going to terminating...")
+        info_type = ErrorType.Server
+        info_code = ErrorCode.S0002
+        info_msg = error_format.format(info_type.name, info_code.name, "The triton server is going to terminate...")
+        warning_logger.info(info_msg)
         self.model.stop()
+        os.system("""
+            bash -c 'pids=$(ps auxww | grep -E "triton_python_backend_stub|multiprocessing.resource_tracker|engine.py" | grep -v grep | awk '"'"'{print $2}'"'"');
+            echo $pids;
+            for pid in ${pids[@]}; do
+                kill -9 ${pid}
+            done;'
+            """)
         logger.info("The triton server is terminated, byebye.")
diff --git a/llm/fastdeploy_llm/utils/launch_infer.py b/llm/fastdeploy_llm/utils/launch_infer.py
index eca1198202..2f23aaed5b 100644
--- a/llm/fastdeploy_llm/utils/launch_infer.py
+++ b/llm/fastdeploy_llm/utils/launch_infer.py
@@ -41,14 +41,17 @@ def launch(device_ids, **kwargs: dict):
             + "launch_infer.launch, please set them and try again".format(
                 missing_args))
 
-    pd_cmd = "python3 -m paddle.distributed.launch --devices {} {} {}".format(
-        device_ids, infer_script_path, ' '.join(args))
+    #pd_cmd = "python3 -m paddle.distributed.launch --devices {} {} {}".format(
+    #    device_ids, infer_script_path, ' '.join(args))
+    pd_cmd = "python3 {} {}".format(infer_script_path, ' '.join(args))
     logger.info("Launch model with command: {}".format(pd_cmd))
     logger.info("Model is initializing...")
+    log_home = os.getenv("LOG_HOME", ".")
+    infer_logger = open('{}/infer.log'.format(log_home), 'a')
     p = subprocess.Popen(
         pd_cmd,
         shell=True,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT,
+        stdout=infer_logger,
+        stderr=infer_logger,
         preexec_fn=os.setsid)
     return p
diff --git a/llm/fastdeploy_llm/utils/logging_util.py b/llm/fastdeploy_llm/utils/logging_util.py
index d34b9dae54..5bb569e612 100644
--- a/llm/fastdeploy_llm/utils/logging_util.py
+++ b/llm/fastdeploy_llm/utils/logging_util.py
@@ -16,6 +16,7 @@
 import logging
 import threading
 import time
+from enum import Enum
 from typing import (Any, Generator, Optional, Union)
 
 from logging.handlers import TimedRotatingFileHandler
@@ -42,6 +43,23 @@
 }
 
 
+
+
+error_format = """Error: Type {} Code {} Description: {}"""
+
+class ErrorCode(Enum):
+    C0000 = 0  # the query sent by the client has an invalid format
+    C0001 = 1  # the query sent by the client failed validity checks
+    S0000 = 2  # the service is overloaded
+    S0001 = 3  # the service failed to start properly
+    S0002 = 4  # the service is exiting
+
+class ErrorType(Enum):
+    Query = 0  # query (client-side) error
+    Server = 1  # server-side error
+
+
+
 class Logger(object):
     _DEFAULT_NAME: str = 'FastDeploy'
 
@@ -155,3 +173,4 @@ def use_terminator(self, terminator: str) -> Generator[None, None, None]:
 
 
 logger = Logger()
+warning_logger = Logger(name="fastDeploy_llm_serving_warning")
\ No newline at end of file
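
A minimal usage sketch (not part of the diff above) of how the error-reporting helpers
introduced in fastdeploy_llm/utils/logging_util.py are meant to compose, mirroring the
pattern triton_model.py now follows for every rejected request. error_format, ErrorCode,
ErrorType and warning_logger come from this patch; the function name log_query_error and
the example arguments are illustrative assumptions only.

    # Usage sketch, assuming this patch is applied.
    from fastdeploy_llm.utils.logging_util import (ErrorCode, ErrorType,
                                                   error_format, warning_logger)

    def log_query_error(task_id, reason):
        # Build the "Error: Type ... Code ... Description: ..." line once, write it
        # to the warning log, and reuse the same string in the error response --
        # the same flow execute() uses before returning a TritonError.
        error_info = "Task id conflict with {}. {}".format(task_id, reason)
        error_msg = error_format.format(ErrorType.Query.name, ErrorCode.C0001.name,
                                        error_info)
        warning_logger.error(error_msg)
        return error_msg

    # log_query_error("req-123", "duplicate request") logs and returns:
    # Error: Type Query Code C0001 Description: Task id conflict with req-123. duplicate request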