
Commit

update
cyx-6 committed Nov 20, 2024
1 parent 1865b55 commit 14c92ce
Showing 3 changed files with 48 additions and 10 deletions.
49 changes: 41 additions & 8 deletions python/mlc_llm/bench/api_endpoint.py
@@ -41,19 +41,23 @@ def __init__( # pylint: disable=too-many-arguments
self,
host: str,
port: int,
backend: str,
timeout: Optional[float] = None,
include_server_metrics: bool = False,
no_debug_config: bool = False,
) -> None:
super().__init__(include_server_metrics=include_server_metrics)

import aiohttp # pylint: disable=import-outside-toplevel,import-error

self.backend = backend
self.timeout = timeout
self.client: aiohttp.ClientSession = None
self.url = f"http://{host}:{port}/v1/chat/completions"
self.headers = {"Content-Type": "application/json"}
if os.getenv("MLC_LLM_API_KEY"):
self.headers["Authorization"] = f"Bearer {os.getenv('MLC_LLM_API_KEY')}"
self.no_debug_config = no_debug_config

async def __aenter__(self) -> Self:
import aiohttp # pylint: disable=import-outside-toplevel,import-error
@@ -80,13 +84,28 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements,too
and request_record.chat_cmpl.debug_config.ignore_eos
):
payload["ignore_eos"] = True
if not self.no_debug_config:
payload["debug_config"] = {"ignore_eos": True}

print(payload)

if "response_format" in payload and "json_schema" in payload["response_format"]:
payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
payload["response_format"].pop("json_schema")

if self.backend == "vllm":
if payload["debug_config"] and "ignore_eos" in payload["debug_config"]:
payload["ignore_eos"] = payload["debug_config"]["ignore_eos"]
payload.pop("debug_config")
if "response_format" in payload:
if "json_schema" in payload["response_format"]:
payload["guided_json"] = json.loads(payload["response_format"]["json_schema"])
payload["guided_decoding_backend"] = "outlines"
payload.pop("response_format")
elif self.backend == "llama.cpp":
if "response_format" in payload and "schema" in payload["response_format"]:
payload["response_format"]["schema"] = json.loads(
payload["response_format"]["json_schema"]
)
payload["response_format"].pop("json_schema")
else:
if "response_format" in payload and "json_schema" in payload["response_format"]:
payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
payload["response_format"].pop("json_schema")
generated_text = ""
first_chunk_output_str = ""
time_to_first_token_s = None
@@ -447,19 +466,33 @@ async def __call__( # pylint: disable=too-many-branches,too-many-locals,too-man
"sglang",
"tensorrt-llm",
"vllm",
"vllm-chat",
"llama.cpp-chat",
]


def create_api_endpoint(args: argparse.Namespace) -> APIEndPoint:
"""Create an API endpoint instance with regard to the specified endpoint kind."""
if args.api_endpoint in ["openai", "mlc", "sglang"]:
return OpenAIEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
if args.api_endpoint == "vllm":
if args.api_endpoint in ["vllm", "llama.cpp"]:
return OpenAIEndPoint(
args.host, args.port, args.timeout, include_server_metrics=False, no_debug_config=True
)
if args.api_endpoint == "openai-chat":
return OpenAIChatEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
return OpenAIChatEndPoint(
args.host, args.port, args.timeout, args.api_endpoint, args.include_server_metrics
)
if args.api_endpoint in ["vllm-chat", "llama.cpp-chat"]:
return OpenAIChatEndPoint(
args.host,
args.port,
args.api_endpoint[:-5],
args.timeout,
include_server_metrics=False,
no_debug_config=True,
)

if args.api_endpoint == "tensorrt-llm":
return TensorRTLLMEndPoint(args.host, args.port, args.timeout)
raise ValueError(f'Unrecognized endpoint "{args.api_endpoint}"')
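
As a quick reference for the branches added above, here is a minimal, self-contained sketch of the backend-specific payload translation. It is an illustration only, not code from the repository: the helper name translate_payload and the sample request are invented, and the llama.cpp branch is written assuming the intent is to key off the json_schema field.

# Hedged, self-contained sketch -- not code from the repository. The helper name
# `translate_payload` and the sample request are invented for illustration; the
# field renames mirror the backend branches added in the __call__ diff above.
import json
from typing import Any, Dict

def translate_payload(payload: Dict[str, Any], backend: str) -> Dict[str, Any]:
    """Rewrite an OpenAI-style request payload for the target backend."""
    if backend == "vllm":
        # vLLM takes a top-level ignore_eos flag instead of MLC's debug_config.
        debug_config = payload.pop("debug_config", None)
        if debug_config and "ignore_eos" in debug_config:
            payload["ignore_eos"] = debug_config["ignore_eos"]
        # vLLM expects guided-decoding fields instead of response_format.
        response_format = payload.pop("response_format", None)
        if response_format and "json_schema" in response_format:
            payload["guided_json"] = json.loads(response_format["json_schema"])
            payload["guided_decoding_backend"] = "outlines"
    elif backend == "llama.cpp":
        # llama.cpp keeps response_format but wants the schema as a parsed object.
        response_format = payload.get("response_format", {})
        if "json_schema" in response_format:
            response_format["schema"] = json.loads(response_format.pop("json_schema"))
    else:
        # Other OpenAI-compatible servers: rename json_schema to schema, keep the string.
        response_format = payload.get("response_format", {})
        if "json_schema" in response_format:
            response_format["schema"] = response_format.pop("json_schema")
    return payload

request = {
    "messages": [{"role": "user", "content": "Reply in JSON."}],
    "debug_config": {"ignore_eos": True},
    "response_format": {"type": "json_object", "json_schema": '{"type": "object"}'},
}
print(translate_payload(dict(request), "vllm"))
# -> {'messages': [...], 'ignore_eos': True,
#     'guided_json': {'type': 'object'}, 'guided_decoding_backend': 'outlines'}

Per the factory above, the -chat endpoint kinds (vllm-chat, llama.cpp-chat) derive the backend name by stripping the "-chat" suffix (args.api_endpoint[:-5]) before constructing OpenAIChatEndPoint.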
7 changes: 6 additions & 1 deletion python/mlc_llm/bench/dataset.py
@@ -248,7 +248,12 @@ def __init__(self, tokenizer: AutoTokenizer) -> None:
self.dataset = []
for data in raw_dataset:
data = self._process_data(data)
messages = data["prompt"]
messages = [
{
"content": data["prompt"][0]["content"] + " " + data["prompt"][1]["content"],
"role": data["prompt"][1]["role"],
},
]
schema = {
"type": "json_object",
"schema": data["schema"],
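
To make the new prompt handling concrete, here is a small hedged illustration (the sample data is invented): the first message, typically a system prompt, is folded into the second message's content, so each request carries a single message with the second message's role.

# Invented sample data; illustrates the flattening performed in the loop above.
data = {
    "prompt": [
        {"role": "system", "content": "Answer with JSON that matches the schema."},
        {"role": "user", "content": "List three fruits."},
    ]
}
messages = [
    {
        "content": data["prompt"][0]["content"] + " " + data["prompt"][1]["content"],
        "role": data["prompt"][1]["role"],
    },
]
# messages == [{"content": "Answer with JSON that matches the schema. List three fruits.",
#               "role": "user"}]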
2 changes: 1 addition & 1 deletion python/mlc_llm/bench/request_processor.py
@@ -131,7 +131,7 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
request_record.chat_cmpl.top_p = self.top_p
request_record.chat_cmpl.frequency_penalty = 0.0
request_record.chat_cmpl.presence_penalty = 0.0
request_record.chat_cmpl.tool_choice = "none"
request_record.chat_cmpl.tool_choice = None
if self.ignore_eos:
request_record.chat_cmpl.debug_config = DebugConfig(ignore_eos=True)
return request_records
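
For orientation, a simplified, hypothetical stand-in for the override step above (the real request and DebugConfig classes differ in detail): sampling parameters are pinned, tool_choice is now left unset (None) rather than the string "none", and ignore_eos is carried through a DebugConfig.

# Hypothetical, simplified stand-ins for the benchmark's request objects; for
# illustration of the override step only, not the repository's classes.
from dataclasses import dataclass
from typing import Optional

@dataclass
class DebugConfig:
    ignore_eos: bool = False

@dataclass
class ChatCmpl:
    temperature: float = 1.0
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    tool_choice: Optional[str] = "none"
    debug_config: Optional[DebugConfig] = None

def apply_overrides(chat_cmpl: ChatCmpl, temperature: float, top_p: float, ignore_eos: bool) -> ChatCmpl:
    # Mirrors the override step shown in the diff above, in simplified form.
    chat_cmpl.temperature = temperature
    chat_cmpl.top_p = top_p
    chat_cmpl.frequency_penalty = 0.0
    chat_cmpl.presence_penalty = 0.0
    chat_cmpl.tool_choice = None  # previously the string "none"
    if ignore_eos:
        chat_cmpl.debug_config = DebugConfig(ignore_eos=True)
    return chat_cmpl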
