Commit
Add fastapi support
rainyfly committed Feb 18, 2024
1 parent 7bddc67 commit 23d877f
Showing 4 changed files with 386 additions and 0 deletions.
281 changes: 281 additions & 0 deletions llm/fastdeploy_llm/server/api.py
@@ -0,0 +1,281 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json
import os
import queue
import time
import uuid

from fastapi import Request, HTTPException
from fastapi.responses import Response, JSONResponse
import google.protobuf.text_format as text_format
import google.protobuf.json_format as json_format
from tritonclient.grpc.model_config_pb2 import ModelConfig

from fastdeploy_llm.serving.serving_model import ServingModel
from fastdeploy_llm.utils.logging_util import logger, warning_logger
from fastdeploy_llm.utils.logging_util import error_format, ErrorCode, ErrorType
from fastdeploy_llm.task import Task, BatchTask
import fastdeploy_llm as fdlm

def pbtxt2json(content: str):
    """
    Convert a protocol message in text format to a JSON string.
    """
message = text_format.Parse(content, ModelConfig())
json_string = json_format.MessageToJson(message)
return json_string
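# For illustration, a hypothetical round trip (field names follow Triton's ModelConfig
# proto; MessageToJson renders them in lowerCamelCase, roughly):
#     pbtxt2json('max_batch_size: 4')  ->  '{\n  "maxBatchSize": 4\n}'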


request_start_time_dict = {}   # task_id -> request arrival time
response_dict = {}             # task_id -> finished result, filled by the model callback
event_dict = {}                # task_id -> asyncio.Event the HTTP handler waits on
response_checked_dict = {}     # task_id -> result handed over to the HTTP handler


def stream_call_back(call_back_task, token_tuple, index, is_last_token,
sender):
if is_last_token:
all_token_ids = [t[0] for t in call_back_task.result.completion_tokens] + [token_tuple[0]]
all_strs = "".join([t[1] for t in call_back_task.result.completion_tokens]) + token_tuple[1]
out = dict()
out["result"] = all_strs
out["req_id"] = call_back_task.task_id
out["token_ids"] = all_token_ids
        out['send_idx'] = 0  # the whole result is returned in one response
out["is_end"] = True
response_dict[call_back_task.task_id] = out
logger.info("Model output for req_id: {} results_all: {} tokens_all: {} inference_cost_time: {} ms".format(
call_back_task.task_id, all_strs, all_token_ids, (time.time() - call_back_task.inference_start_time) * 1000))


def parse(parameters_config, name, default_value=None):
    if name not in parameters_config:
        if default_value is not None:  # a falsy default such as 0 is still a valid default
            return default_value
        else:
            raise Exception(
                "Cannot find key:{} while parsing parameters.".format(name))
    return parameters_config[name]["stringValue"]
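# For illustration (hypothetical values): a config.pbtxt entry such as
#     parameters { key: "MAX_BATCH_SIZE" value { string_value: "4" } }
# becomes {"MAX_BATCH_SIZE": {"stringValue": "4"}} after pbtxt2json, so
# parse(parameters, "MAX_BATCH_SIZE", 4) returns the string "4", which callers wrap in int().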


class ModelExecutor:
def __init__(self, model_dir):
config = fdlm.Config(model_dir)
config_pb_path = os.path.join(model_dir, 'config.pbtxt')
if os.path.exists(config_pb_path):
with open(config_pb_path, 'r') as f:
data = f.read()
json_str = pbtxt2json(data)
parameters = json.loads(json_str)['parameters']
config.max_batch_size = int(parse(parameters, "MAX_BATCH_SIZE", 4))
config.mp_num = int(parse(parameters, "MP_NUM", 1))
if config.mp_num < 0:
config.mp_num = None
config.max_dec_len = int(parse(parameters, "MAX_DEC_LEN", 1024))
config.max_seq_len = int(parse(parameters, "MAX_SEQ_LEN", 1024))
config.decode_strategy = parse(parameters, "DECODE_STRATEGY",
"sampling")
config.stop_threshold = int(parse(parameters, "STOP_THRESHOLD", 2))
config.disable_dynamic_batching = int(
parse(parameters, "DISABLE_DYNAMIC_BATCHING", 0))
config.max_queue_num = int(parse(parameters, "MAX_QUEUE_NUM", 512))
config.is_ptuning = int(parse(parameters, "IS_PTUNING", 0))
if config.is_ptuning:
config.model_prompt_dir_path = parse(parameters,
"MODEL_PROMPT_DIR_PATH")
config.max_prefix_len = int(parse(parameters, "MAX_PREFIX_LEN"))
config.load_environment_variables()

self.wait_time_out = 60
self.config = config
self.response_handler = dict()
self.model = None


    def prepare_model(self):
        # This method may only be called once across all processes
self.model = ServingModel(self.config)
self.model.model.stream_sender = self.response_handler
self.model.start()

def execute(self, req_dict):
if self.model is None:
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "Model is not ready"
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return
        # 1. Deserialize the request data and validate it
task = Task()
try:
task.from_dict(req_dict)
request_start_time = time.time()
task.set_request_start_time(request_start_time)
except Exception as e:
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "There's error while deserializing data from request, received data = {} error={}".format(req_dict, e)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return

        # 2. Check for a task id conflict
if task.task_id is None:
task.task_id = str(uuid.uuid4())
request_start_time_dict[task.task_id] = request_start_time
if task.task_id in self.response_handler:
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "Task id conflict with {}.".format(task.task_id)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return

        # 3. Validate the parameters in the task
try:
task.check(self.config.max_dec_len)
except Exception as e:
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "There's error while checking task, task={} error={}".format(task, e)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return

        # 4. Check whether the requests queue is full
if self.model.requests_queue.qsize() > self.config.max_queue_num:
error_type = ErrorType.Server
error_code = ErrorCode.S0000
error_info = "The queue is full now(size={}), please wait for a while.".format(self.config.max_queue_num)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return

        # 5. Check whether the prefix embedding exists
        if self.config.is_ptuning and task.model_id is not None:
            np_file_path = os.path.join(self.config.model_prompt_dir_path,
                                        "8-{}".format(task.model_id), "1",
                                        "task_prompt_embeddings.npy")
            if not os.path.exists(np_file_path):
                error_type = ErrorType.Query
                error_code = ErrorCode.C0001
                error_info = "Cannot find prefix embedding for model_id={}, expected path: {}".format(
                    task.model_id, np_file_path)
                error_msg = error_format.format(error_type.name, error_code.name, error_info)
                warning_logger.error(error_msg)
                response_dict[req_dict['req_id']] = error_msg
                return

        # 6. Add the task to the requests queue
task.call_back_func = stream_call_back
try:
self.model.add_request(task)
except queue.Full as e:
# Log error for Server
error_type = ErrorType.Server
error_code = ErrorCode.S0000
error_info = "The queue is full now(size={}), please scale service.".format(self.config.max_queue_num)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
# Log error for query
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "There's error while inserting new request, task={} error={}".format(task, "service too busy")
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return

except Exception as e:
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "There's error while inserting new request, task={} error={}".format(task, e)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return

return task

async def inference(self, request_in: Request):
"""
API for generation task.
"""
start_time = time.time()
        try:
            input_dict = await request_in.json()
            logger.info("received req_dict {}".format(input_dict))
        except Exception:
            error_type = ErrorType.Query
            error_code = ErrorCode.C0001
            content = await request_in.body()
            error_info = "request body is not in valid json format, received data = {}".format(content)
            error_msg = error_format.format(error_type.name, error_code.name, error_info)
            warning_logger.error(error_msg)
            raise HTTPException(status_code=400, detail=error_msg)
        task = self.execute(input_dict)
        if task is None:
            # execute() has already logged the error and stored a message under this req_id
            req_id = input_dict.get("req_id")
            error_msg = response_dict.pop(req_id, None) or response_checked_dict.pop(
                req_id, "Request preprocessing failed.")
            raise HTTPException(status_code=400, detail=error_msg)
        event_dict[task.task_id] = asyncio.Event()
        try:
            await asyncio.wait_for(event_dict[task.task_id].wait(), self.wait_time_out)
        except asyncio.TimeoutError:
            del event_dict[task.task_id]  # avoid leaking the event on timeout
            error_type = ErrorType.Query
            error_code = ErrorCode.C0001
            error_info = "Timeout while waiting for the inference result."
            error_msg = error_format.format(error_type.name, error_code.name, error_info)
            warning_logger.error(error_msg)
            raise HTTPException(status_code=400, detail=error_msg)
result = response_checked_dict[task.task_id]
del response_checked_dict[task.task_id]
del event_dict[task.task_id]
logger.info("req_id: {} has sent back to client, request_cost_time: {} ms".format(task.task_id, (time.time() - start_time) * 1000))
return JSONResponse(result)

    def check_live(self):
        """
        API for probing the HTTP app's status.
        """
        if self.model is not None and self.model.model._is_engine_initialized() and (
                self.model.model.engine_proc.poll() is None):
            logger.info("check_live: True")
            return Response(status_code=200)
        else:
            logger.info("check_live: False")
            return Response(status_code=500)


async def watch_result():
    while True:
        await asyncio.sleep(0.01)  # poll for new results every 10 ms
        # Move each finished result out of response_dict before waking its waiter,
        # so every result is handed over exactly once and no key is deleted twice.
        for task_id in list(response_dict.keys()):
            response_checked_dict[task_id] = response_dict.pop(task_id)
            if task_id in event_dict:
                event_dict[task_id].set()


model_dir = os.getenv("MODEL_DIR", None)
if model_dir is None:
raise ValueError("Environment variable MODEL_DIR must be set")
model_executor = ModelExecutor(model_dir)



49 changes: 49 additions & 0 deletions llm/fastdeploy_llm/server/app.py
@@ -0,0 +1,49 @@
"""
Http FastAPI app
"""
import asyncio
import subprocess
import os

import uvicorn
from fastapi import FastAPI, APIRouter

from .api import watch_result, model_executor

check_live_path = '/ready'
inference_path = '/v1/chat/completions'

def create_app():
"""
Create a FastAPI app.
"""
app = FastAPI()
router = APIRouter()
url_mappings = [
(check_live_path, model_executor.check_live, ["GET"]),
(inference_path, model_executor.inference, ["POST"]),
]
for url, view_func, supported_methods in url_mappings:
router.add_api_route(url, endpoint=view_func, methods=supported_methods)
app.include_router(router)
return app

app = create_app()

# FastAPI startup event
@app.on_event("startup")
async def startup_event():
    """
    Load the model and start watching for inference results.
    """
    model_executor.prepare_model()
    # keep a reference on the app so the polling task is not garbage collected
    app.state.watch_result_task = asyncio.create_task(watch_result())

def run(args):
    """
    Start the HTTP server.
    """
uvicorn.run("fastdeploy_llm.server.app:app", host="0.0.0.0", port=int(args.http_port), log_level="info")
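
For reference, a minimal sketch of starting the server by hand; the Namespace stands in for the parsed CLI args and the model path is hypothetical. MODEL_DIR must be set before the app module is imported, because api.py reads it at import time:

    import argparse
    import os

    os.environ["MODEL_DIR"] = "/path/to/model"  # hypothetical path
    from fastdeploy_llm.server.app import run
    run(argparse.Namespace(http_port=8100))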



45 changes: 45 additions & 0 deletions llm/fastdeploy_llm/server/cmd.py
@@ -0,0 +1,45 @@
"""
Command line entrypoint for scheduler
"""
import argparse
import multiprocessing
import os
import signal
import json

from .env import fastdeploy_llm_home
from .env import pgid_file_path


if not os.path.exists(fastdeploy_llm_home):
os.mkdir(fastdeploy_llm_home)


def main():
"""
Main function.
"""
parser = argparse.ArgumentParser("Fastdeploy llm launcher")
parser.add_argument("--http-port", type=int, default=8100)
parser.add_argument('cmd', nargs='?', help='command for fastdeploy_llm')
args = parser.parse_args()
if args.cmd == 'stop':
if os.path.exists(pgid_file_path):
with open(pgid_file_path, 'r') as f:
pgid = f.read().strip()
            # send SIGTERM to stop the whole service process group
os.killpg(int(pgid), signal.SIGTERM)
return
if os.getenv("MODEL_DIR", None) is None:
raise ValueError("Environment variable MODEL_DIR must be set")
from .app import run
    # reset the resource files the service needs at startup
if os.path.exists(pgid_file_path):
os.remove(pgid_file_path)
with open(pgid_file_path, 'w') as f:
        f.write(str(os.getpgid(os.getpid())))  # record the process group id (PGID)
run(args)

if __name__ == '__main__':
    try:
        os.setsid()  # start a new session so the whole service shares one PGID
    except OSError:
        pass  # already a session or process-group leader
    main()
11 changes: 11 additions & 0 deletions llm/fastdeploy_llm/server/env.py
@@ -0,0 +1,11 @@
"""
保存服务有关的内部用的环境资源
"""
import os

# home directory for the service's resource files
fastdeploy_llm_home = os.path.join(os.path.expanduser('~'), '.fastdeploy_llm')

# file that stores the PGID of the service process group
pgid_file_path = os.path.join(fastdeploy_llm_home, 'fastdeploy_llm.pgid')
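
For illustration, the "stop" command in cmd.py reads this file and signals the recorded process group; the equivalent by hand (a sketch, assuming the file exists):

    import os
    import signal

    with open(os.path.expanduser('~/.fastdeploy_llm/fastdeploy_llm.pgid')) as f:
        os.killpg(int(f.read().strip()), signal.SIGTERM)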
