Skip to content

Commit

Permalink
Fix sglang worker (lm-sys#2953)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Jan 24, 2024
1 parent df81798 commit 5f46ff4
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 17 deletions.
2 changes: 1 addition & 1 deletion fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2198,7 +2198,7 @@ def match(self, model_path: str):
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("vicuna_v1.1")


class YuanAdapter(BaseModelAdapter):
"""The model adapter for Yuan"""

Expand Down
29 changes: 13 additions & 16 deletions fastchat/serve/sglang_worker.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
"""
A model worker that executes the model based on SGLANG.
Usage:
python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 --worker-address http://localhost:30000
"""

import argparse
import asyncio
import json
import multiprocessing
from typing import List

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from sglang import (
function,
image,
system,
user,
assistant,
gen,
set_default_backend,
Runtime,
)
import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenizer, get_config
from sglang.srt.utils import load_image

Expand All @@ -33,14 +28,14 @@
app = FastAPI()


@sgl.function
def pipeline(s, prompt, max_tokens):
    """SGLang program that feeds an interleaved text/image prompt to the model.

    Args:
        s: the SGLang program state (implicitly supplied by the ``@sgl.function``
            decorator at call time).
        prompt: an iterable whose items are either text chunks (``str``) or
            image payloads (anything accepted by ``sgl.image``) — presumably
            image URLs or decoded images; confirm against the caller.
        max_tokens: generation budget forwarded to ``sgl.gen``.

    The model's completion is stored in the state under the key ``"response"``.
    """
    for p in prompt:
        if isinstance(p, str):
            # Plain text chunk: append verbatim to the running prompt.
            s += p
        else:
            # Non-string chunk is treated as an image input.
            s += sgl.image(p)
    s += sgl.gen("response", max_tokens=max_tokens)


class SGLWorker(BaseModelWorker):
Expand All @@ -55,7 +50,7 @@ def __init__(
limit_worker_concurrency: int,
no_register: bool,
conv_template: str,
runtime: Runtime,
runtime: sgl.Runtime,
trust_remote_code: bool,
):
super().__init__(
Expand Down Expand Up @@ -270,14 +265,16 @@ async def api_model_details(request: Request):
args.model_path if args.tokenizer_path == "" else args.tokenizer_path
)

runtime = Runtime(
multiprocessing.set_start_method("spawn", force=True)
runtime = sgl.Runtime(
model_path=args.model_path,
tokenizer_path=args.tokenizer_path,
trust_remote_code=args.trust_remote_code,
mem_fraction_static=args.mem_fraction_static,
tp_size=args.tp_size,
log_level="info",
)
set_default_backend(runtime)
sgl.set_default_backend(runtime)

worker = SGLWorker(
args.controller_address,
Expand Down
1 change: 1 addition & 0 deletions tests/test_openai_vision_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
Test the OpenAI compatible server
Launch:
python3 launch_openai_api_test_server.py --multimodal
"""
Expand Down

0 comments on commit 5f46ff4

Please sign in to comment.