From 5f46ff4b0c0e9b5f4d3378a7a19bf3974a19763b Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Wed, 24 Jan 2024 02:24:16 -0800
Subject: [PATCH] Fix sglang worker (#2953)

---
 fastchat/model/model_adapter.py |  2 +-
 fastchat/serve/sglang_worker.py | 29 +++++++++++++----------------
 tests/test_openai_vision_api.py |  1 +
 3 files changed, 15 insertions(+), 17 deletions(-)

diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index acdab09bd..e519e66fb 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -2198,7 +2198,7 @@ def match(self, model_path: str):
 
     def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("vicuna_v1.1")
-
+
 
 class YuanAdapter(BaseModelAdapter):
     """The model adapter for Yuan"""
diff --git a/fastchat/serve/sglang_worker.py b/fastchat/serve/sglang_worker.py
index 18c4be361..6660be1d8 100644
--- a/fastchat/serve/sglang_worker.py
+++ b/fastchat/serve/sglang_worker.py
@@ -1,25 +1,20 @@
 """
 A model worker that executes the model based on SGLANG.
+
+Usage:
+python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 --worker-address http://localhost:30000
 """
 
 import argparse
 import asyncio
 import json
+import multiprocessing
 from typing import List
 
 from fastapi import FastAPI, Request, BackgroundTasks
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
-from sglang import (
-    function,
-    image,
-    system,
-    user,
-    assistant,
-    gen,
-    set_default_backend,
-    Runtime,
-)
+import sglang as sgl
 from sglang.srt.hf_transformers_utils import get_tokenizer, get_config
 from sglang.srt.utils import load_image
 
@@ -33,14 +28,14 @@
 app = FastAPI()
 
 
-@function
+@sgl.function
 def pipeline(s, prompt, max_tokens):
     for p in prompt:
         if isinstance(p, str):
             s += p
         else:
-            s += image(p)
-    s += gen("response", max_tokens=max_tokens)
+            s += sgl.image(p)
+    s += sgl.gen("response", max_tokens=max_tokens)
 
 
 class SGLWorker(BaseModelWorker):
@@ -55,7 +50,7 @@ def __init__(
         limit_worker_concurrency: int,
         no_register: bool,
         conv_template: str,
-        runtime: Runtime,
+        runtime: sgl.Runtime,
         trust_remote_code: bool,
     ):
         super().__init__(
@@ -270,14 +265,16 @@ async def api_model_details(request: Request):
         args.model_path if args.tokenizer_path == "" else args.tokenizer_path
     )
 
-    runtime = Runtime(
+    multiprocessing.set_start_method("spawn", force=True)
+    runtime = sgl.Runtime(
         model_path=args.model_path,
         tokenizer_path=args.tokenizer_path,
         trust_remote_code=args.trust_remote_code,
         mem_fraction_static=args.mem_fraction_static,
        tp_size=args.tp_size,
+        log_level="info",
     )
-    set_default_backend(runtime)
+    sgl.set_default_backend(runtime)
 
     worker = SGLWorker(
         args.controller_address,
diff --git a/tests/test_openai_vision_api.py b/tests/test_openai_vision_api.py
index b1eb0ac8b..a54d7d575 100644
--- a/tests/test_openai_vision_api.py
+++ b/tests/test_openai_vision_api.py
@@ -1,5 +1,6 @@
 """
 Test the OpenAI compatible server
+
 Launch:
 python3 launch_openai_api_test_server.py --multimodal
 """
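
Note (illustration, not part of the patch): the hunks above move from importing `function`, `image`, `gen`, `set_default_backend`, and `Runtime` individually to the namespaced `import sglang as sgl` style. Below is a minimal, self-contained sketch of driving an `sgl.function` against an `sgl.Runtime` backend in that style; the model path, prompt, and `__main__` scaffolding are assumptions for illustration, and `pipeline.run` / `state["response"]` / `runtime.shutdown()` are standard sglang frontend calls rather than anything added by this patch.

# Illustrative sketch only: exercises the sgl-namespaced API that the patch
# switches to. The model path and prompt are placeholder assumptions.
import sglang as sgl


@sgl.function
def pipeline(s, prompt, max_tokens):
    # Mirrors the worker's pipeline: strings are appended as text segments,
    # anything else is wrapped as an image segment.
    for p in prompt:
        if isinstance(p, str):
            s += p
        else:
            s += sgl.image(p)
    s += sgl.gen("response", max_tokens=max_tokens)


if __name__ == "__main__":
    # Spin up a local runtime and register it as the default backend,
    # as the worker does at startup.
    runtime = sgl.Runtime(model_path="liuhaotian/llava-v1.5-7b")  # placeholder model
    sgl.set_default_backend(runtime)

    state = pipeline.run(
        prompt=["USER: Say hello in one short sentence. ASSISTANT:"],
        max_tokens=32,
    )
    print(state["response"])

    runtime.shutdown()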