diff --git a/README.md b/README.md
index 0328805..61ee3b1 100644
--- a/README.md
+++ b/README.md
@@ -138,6 +138,46 @@ We provide a simple example to run inference on a Huggingface LLM model. The scr
 CUDA_VISIBLE_DEVICES=0 python example/interface_example.py --model_name_or_path "mistralai/Mixtral-8x7B-Instruct-v0.1" --offload_dir
 ```
 
+### OpenAI-Compatible Server
+
+Start the OpenAI-compatible server locally:
+```bash
+python -m moe_infinity.entrypoints.openai.api_server --model facebook/opt-125m --offload-dir ./offload_dir
+```
+
+Query the model via `/v1/completions`. (We currently only support the required fields, i.e., "model" and "prompt".)
+```bash
+curl http://localhost:8000/v1/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "facebook/opt-125m",
+    "prompt": "Hello, my name is"
+  }'
+```
+You can also use the `openai` Python package to query the model.
+```bash
+pip install openai
+python tests/test_oai_completions.py
+```
+
+Query the model via `/v1/chat/completions`. (We currently only support the required fields, i.e., "model" and "messages".)
+```bash
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "facebook/opt-125m",
+    "messages": [
+      {"role": "system", "content": "You are a helpful assistant."},
+      {"role": "user", "content": "Tell me a joke"}
+    ]
+  }'
+```
+You can also use the `openai` Python package to query the model.
+```bash
+pip install openai
+python tests/test_oai_chat_completions.py
+```
+
 ## Release Plan
 
 We plan to release two functions in the following months:
diff --git a/moe_infinity/entrypoints/openai/__init__.py b/moe_infinity/entrypoints/openai/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/moe_infinity/entrypoints/openai/api_server.py b/moe_infinity/entrypoints/openai/api_server.py
new file mode 100644
index 0000000..d154c79
--- /dev/null
+++ b/moe_infinity/entrypoints/openai/api_server.py
@@ -0,0 +1,286 @@
+# Copyright 2024 TorchMoE Team
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file includes source code adapted from vLLM
+# (https://github.com/vllm-project/vllm),
+# which is also licensed under the Apache License, Version 2.0.
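+
+# Module docstring: a brief usage summary mirroring the README section that
+# documents this server (the model name below is the README's example).
+"""OpenAI-compatible RESTful API server for MoE-Infinity.
+
+Launch (see README):
+    python -m moe_infinity.entrypoints.openai.api_server \
+        --model facebook/opt-125m --offload-dir ./offload_dir
+
+Endpoints: GET /health, POST /v1/completions, POST /v1/chat/completions.
+Only the required request fields ("model" plus "prompt" or "messages") are
+currently supported.
+"""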
+
+import argparse
+import asyncio
+import json
+import os
+import time
+from typing import Tuple
+from queue import Queue
+
+from transformers import TextStreamer
+from moe_infinity import MoE
+import torch
+from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
+
+import fastapi
+import uvicorn
+from fastapi import Request
+from fastapi.responses import JSONResponse, StreamingResponse, Response
+
+# Use the package-qualified import so that
+# `python -m moe_infinity.entrypoints.openai.api_server` (as documented in the
+# README) can resolve the protocol module.
+from moe_infinity.entrypoints.openai.protocol import (
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatMessage, DeltaMessage, ErrorResponse,
+    CompletionRequest, CompletionResponse,
+    CompletionResponseChoice,
+    ModelPermission, ModelCard, ModelList,
+    UsageInfo)
+from moe_infinity.entrypoints.openai.protocol import random_uuid
+
+TIMEOUT_KEEP_ALIVE = 5  # seconds
+# logger = init_logger(__name__)
+model_name = None
+model = None
+tokenizer = None
+model_queue = None
+
+app = fastapi.FastAPI()
+
+
+class TokenStreamer(TextStreamer):
+    """Collects generated token ids instead of printing decoded text."""
+
+    def __init__(self, tokenizer):
+        super().__init__(tokenizer)
+        self.token_cache = []
+        self.encoded = False
+
+    def put(self, value):
+        # The first call receives the prompt tokens; skip them and cache only
+        # the tokens produced during generation.
+        if self.encoded and value is not None:
+            self.token_cache.append(value)
+        else:
+            self.encoded = True
+
+    def end(self):
+        pass
+
+    def get_tokens(self):
+        return self.token_cache
+
+
+def parse_prompt_format(prompt) -> Tuple[bool, list]:
+    # get the prompt, openai supports the following
+    # "a string, array of strings, array of tokens, or array of token arrays."
+    prompt_is_tokens = False
+    prompts = [prompt]  # case 1: a string
+    if isinstance(prompt, list):
+        if len(prompt) == 0:
+            raise ValueError("please provide at least one prompt")
+        elif isinstance(prompt[0], str):
+            prompt_is_tokens = False
+            prompts = prompt  # case 2: array of strings
+        elif isinstance(prompt[0], int):
+            prompt_is_tokens = True
+            prompts = [prompt]  # case 3: array of tokens
+        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
+            prompt_is_tokens = True
+            prompts = prompt  # case 4: array of token arrays
+        else:
+            raise ValueError(
+                "prompt must be a string, array of strings, array of tokens, or array of token arrays"
+            )
+    return prompt_is_tokens, prompts
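+
+# A few illustrative input/output pairs for parse_prompt_format (sketch,
+# mirroring the four cases handled above):
+#   parse_prompt_format("Hi")              -> (False, ["Hi"])
+#   parse_prompt_format(["Hi", "Bye"])     -> (False, ["Hi", "Bye"])
+#   parse_prompt_format([1, 2, 3])         -> (True,  [[1, 2, 3]])
+#   parse_prompt_format([[1, 2], [3, 4]])  -> (True,  [[1, 2], [3, 4]])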
+
+
+def get_available_models() -> ModelList:
+    """Show available models.
+    Right now we only have one model."""
+    model_cards = [
+        ModelCard(id=model_name,
+                  root=model_name,
+                  permission=[ModelPermission()])
+    ]
+    return ModelList(data=model_cards)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="MoE-Infinity OpenAI-Compatible RESTful API server.")
+    parser.add_argument("--host", type=str, default=None, help="host name")
+    parser.add_argument("--port", type=int, default=8000, help="port number")
+    parser.add_argument("--model", type=str, required=True)
+    parser.add_argument("--offload-dir", type=str, required=True)
+    parser.add_argument("--device-memory-ratio", type=float, default=0.75)
+
+    return parser.parse_args()
+
+
+@app.get("/health")
+async def health() -> Response:
+    """Health check."""
+    return Response(status_code=200)
+
+
+# @app.get("/v1/models")
+# async def show_available_models():
+#     models = get_available_models()
+#     return JSONResponse(content=models.model_dump())
+
+
+@app.post("/v1/chat/completions")
+async def chat_completion(request: ChatCompletionRequest,
+                          raw_request: Request):
+    model_name = request.model
+    # Unix timestamp, as expected by OpenAI clients (time.monotonic() is not
+    # a wall-clock time).
+    created_time = int(time.time())
+    request_id = random_uuid()
+
+    prompt = tokenizer.apply_chat_template(
+        conversation=request.messages,
+        tokenize=False,
+    )
+    print(f"prompt: {prompt}")
+
+    token_ids = tokenizer.encode(prompt, return_tensors="pt")
+    token_ids = token_ids.to("cuda:0")
+    print(f"token_ids: {token_ids}")
+    num_prompt_tokens = token_ids.size(1)
+
+    streamer = TokenStreamer(tokenizer)
+
+    # The single-item queue serializes access to the model across requests.
+    token = model_queue.get()
+    _ = model.generate(
+        token_ids,
+        streamer=streamer,
+        **request.to_hf_params()
+    )
+    model_queue.put(token)
+
+    outputs = torch.tensor(streamer.get_tokens())
+    print(f"outputs: {outputs}")
+    num_generated_tokens = len(outputs)
+
+    final_res = tokenizer.decode(outputs, skip_special_tokens=True)
+    print(f"final_res: {final_res}")
+
+    choices = []
+    # role = self.get_chat_request_role(request)
+    role = "assistant"  # FIXME: hardcoded
+    # for output in final_res.outputs:
+    choice_data = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(role=role, content=final_res),
+        finish_reason="stop",
+    )
+    choices.append(choice_data)
+
+    usage = UsageInfo(
+        prompt_tokens=num_prompt_tokens,
+        completion_tokens=num_generated_tokens,
+        total_tokens=num_prompt_tokens + num_generated_tokens
+    )
+    response = ChatCompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        choices=choices,
+        usage=usage,
+    )
+
+    return response
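+
+
+# Example response body produced by the handler above (illustrative values;
+# field names and order follow the pydantic models in protocol.py):
+# {
+#   "id": "chatcmpl-<uuid>",
+#   "object": "chat.completion",
+#   "created": 1718000000,
+#   "model": "facebook/opt-125m",
+#   "choices": [{"index": 0,
+#                "message": {"role": "assistant", "content": "..."},
+#                "finish_reason": "stop"}],
+#   "usage": {"prompt_tokens": 8, "total_tokens": 20, "completion_tokens": 12}
+# }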
+
+
+@app.post("/v1/completions")
+async def completion(request: CompletionRequest, raw_request: Request):
+    model_name = request.model
+    # Unix timestamp, as expected by OpenAI clients.
+    created_time = int(time.time())
+    request_id = random_uuid()
+
+    prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
+
+    choices = []
+    num_prompt_tokens = 0
+    num_generated_tokens = 0
+    for i, prompt in enumerate(prompts):
+        if prompt_is_tokens:
+            # The prompt is already a list of token ids; wrap it in a batch tensor.
+            input_ids = torch.tensor([prompt], dtype=torch.long)
+        else:
+            input_ids = tokenizer.encode(prompt, return_tensors="pt")
+
+        input_ids = input_ids.to("cuda:0")
+        streamer = TokenStreamer(tokenizer)
+
+        token = model_queue.get()
+        _ = model.generate(
+            input_ids,
+            streamer=streamer,
+            **request.to_hf_params()
+        )
+        model_queue.put(token)
+
+        outputs = torch.tensor(streamer.get_tokens())
+        final_res = tokenizer.decode(outputs, skip_special_tokens=True)
+
+        output_text = final_res
+
+        choice_data = CompletionResponseChoice(
+            index=i,
+            text=output_text,
+            logprobs=None,
+            finish_reason="stop",
+        )
+        choices.append(choice_data)
+
+        num_prompt_tokens += input_ids.size(1)
+        num_generated_tokens += len(outputs)
+
+    usage = UsageInfo(
+        prompt_tokens=num_prompt_tokens,
+        completion_tokens=num_generated_tokens,
+        total_tokens=num_prompt_tokens + num_generated_tokens,
+    )
+
+    return CompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        choices=choices,
+        usage=usage,
+    )
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    print(f"args: {args}")
+
+    # Load the tokenizer and model requested on the command line (the README
+    # example uses facebook/opt-125m).
+    model_name = args.model
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    config = {
+        "offload_path": os.path.join(args.offload_dir, model_name),
+        "device_memory_ratio": args.device_memory_ratio,
+    }
+    model = MoE(model_name, config)
+    # model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-16", torch_dtype=torch.float16)
+    # model = OPTForCausalLM.from_pretrained(model_name)
+    # A single-item queue used as a lock so requests run generation one at a time.
+    model_queue = Queue()
+    model_queue.put("token")
+
+    uvicorn.run(app,
+                host=args.host,
+                port=args.port,
+                log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
diff --git a/moe_infinity/entrypoints/openai/protocol.py b/moe_infinity/entrypoints/openai/protocol.py
new file mode 100644
index 0000000..85a81e5
--- /dev/null
+++ b/moe_infinity/entrypoints/openai/protocol.py
@@ -0,0 +1,203 @@
+# Copyright 2024 TorchMoE Team
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file includes source code adapted from vLLM
+# (https://github.com/vllm-project/vllm),
+# which is also licensed under the Apache License, Version 2.0.
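+
+# Module docstring: a short summary of what this file provides (the schemas
+# below are adapted from vLLM / FastChat, per the attribution comments).
+"""Pydantic request/response models for the OpenAI-compatible server.
+
+`CompletionRequest.to_hf_params()` and `ChatCompletionRequest.to_hf_params()`
+map the supported OpenAI sampling fields onto HuggingFace `generate()` kwargs.
+"""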
+
+# Adapted from
+# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+import time
+import uuid
+from typing import Dict, List, Literal, Optional, Union
+
+from pydantic import BaseModel, Field
+
+
+def random_uuid():
+    return str(uuid.uuid4())
+
+
+class ErrorResponse(BaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
+class ModelPermission(BaseModel):
+    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
+    object: str = "model_permission"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    allow_create_engine: bool = False
+    allow_sampling: bool = True
+    allow_logprobs: bool = True
+    allow_search_indices: bool = False
+    allow_view: bool = True
+    allow_fine_tuning: bool = False
+    organization: str = "*"
+    group: Optional[str] = None
+    is_blocking: bool = False
+
+
+class ModelCard(BaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "torchmoe"
+    root: Optional[str] = None
+    parent: Optional[str] = None
+    permission: List[ModelPermission] = Field(default_factory=list)
+
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard] = Field(default_factory=list)
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: Union[str, List[Dict[str, str]]]
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    max_tokens: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+
+    def to_hf_params(self) -> Dict[str, Union[str, int, float, List[int], List[str]]]:
+        # Only forward parameters that HuggingFace `generate()` accepts;
+        # `logit_bias` has no direct generate() equivalent and is ignored.
+        params = {
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+        }
+        if self.max_tokens is not None:
+            params["max_new_tokens"] = self.max_tokens
+        return params
+
+
+class CompletionRequest(BaseModel):
+    model: str
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    suffix: Optional[str] = None
+    max_tokens: Optional[int] = 16
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    stream: Optional[bool] = False
+    logprobs: Optional[int] = None
+    echo: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    best_of: Optional[int] = None
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+
+    def to_hf_params(self) -> Dict[str, Union[str, int, float, List[int], List[str]]]:
+        # Only forward parameters that HuggingFace `generate()` accepts;
+        # `logit_bias` and `best_of` have no direct generate() equivalent.
+        params = {
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+        }
+        if self.max_tokens is not None:
+            params["max_new_tokens"] = self.max_tokens
+        return params
+
+
+class LogProbs(BaseModel):
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
+
+
+class CompletionResponseChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[Literal["stop", "length"]] = None
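+
+
+# Serialized shape of CompletionResponse below (illustrative values):
+# {"id": "cmpl-<uuid>", "object": "text_completion", "created": 1718000000,
+#  "model": "facebook/opt-125m",
+#  "choices": [{"index": 0, "text": "...", "logprobs": null, "finish_reason": "stop"}],
+#  "usage": {"prompt_tokens": 5, "total_tokens": 21, "completion_tokens": 16}}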
"length"]] = None + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) diff --git a/requirements.txt b/requirements.txt index 6b5ecc7..a549a2f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,5 @@ scipy chardet optimum>=1.17.1 auto_gptq +fastapi==2.16.2 +uvicorn==0.28.0 \ No newline at end of file diff --git a/tests/test_oai_chat_completions.py b/tests/test_oai_chat_completions.py new file mode 100644 index 0000000..930d136 --- /dev/null +++ b/tests/test_oai_chat_completions.py @@ -0,0 +1,18 @@ +from openai import OpenAI + +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) + +chat_response = client.chat.completions.create( + model="facebook/opt-125m", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a joke"}, + ] +) +print("Chat response:", chat_response) \ No newline at end of file diff --git a/tests/test_oai_completions.py b/tests/test_oai_completions.py new file mode 100644 index 0000000..396ce01 --- /dev/null +++ b/tests/test_oai_completions.py @@ -0,0 +1,11 @@ +from openai import OpenAI + +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) +completion = client.completions.create(model="facebook/opt-125m", + prompt="Hello, my name is") +print("Completion result:", completion) \ No newline at end of file