diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py
index 2dd508c6..cd9df7ee 100644
--- a/aana/api/request_handler.py
+++ b/aana/api/request_handler.py
@@ -1,6 +1,11 @@
+import json
+import time
 from typing import Any
+from uuid import uuid4
 
+import ray
 from fastapi.openapi.utils import get_openapi
+from fastapi.responses import StreamingResponse
 from ray import serve
 
 from aana.api.api_generation import Endpoint, add_custom_schemas_to_openapi_schema
@@ -8,7 +13,10 @@
 from aana.api.event_handlers.event_manager import EventManager
 from aana.api.responses import AanaJSONResponse
 from aana.configs.settings import settings as aana_settings
+from aana.core.models.chat import ChatCompletion, ChatCompletionRequest, ChatDialog
+from aana.core.models.sampling import SamplingParams
 from aana.core.models.task import TaskId
+from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
 from aana.storage.services.task import TaskInfo, delete_task, get_task_info
 
 
@@ -124,3 +132,75 @@ async def delete_task_endpoint(self, task_id: str) -> TaskId:
         """
         task = delete_task(task_id)
         return TaskId(task_id=str(task.id))
+
+    @app.post("/chat/completions", response_model=ChatCompletion)
+    async def chat_completions(self, request: ChatCompletionRequest):
+        """Handle chat completions requests for OpenAI compatible API."""
+
+        async def _async_chat_completions(
+            handle: AanaDeploymentHandle,
+            dialog: ChatDialog,
+            sampling_params: SamplingParams,
+        ):
+            async for response in handle.chat_stream(
+                dialog=dialog, sampling_params=sampling_params
+            ):
+                chunk = {
+                    "id": f"chatcmpl-{uuid4().hex}",
+                    "object": "chat.completion.chunk",
+                    "model": request.model,
+                    "created": int(time.time()),
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {"content": response["text"], "role": "assistant"},
+                        }
+                    ],
+                }
+                yield f"data: {json.dumps(chunk)}\n\n"
+            yield "data: [DONE]\n\n"
+
+        # Check if the deployment exists
+        try:
+            handle = await AanaDeploymentHandle.create(request.model)
+        except ray.serve.exceptions.RayServeException:
+            return AanaJSONResponse(
+                content={
+                    "error": {"message": f"The model `{request.model}` does not exist."}
+                },
+                status_code=404,
+            )
+
+        # Check if the deployment is a chat model
+        if not hasattr(handle, "chat") or not hasattr(handle, "chat_stream"):
+            return AanaJSONResponse(
+                content={
+                    "error": {"message": f"The model `{request.model}` does not exist."}
+                },
+                status_code=404,
+            )
+
+        dialog = ChatDialog(
+            messages=request.messages,
+        )
+
+        sampling_params = SamplingParams(
+            temperature=request.temperature,
+            max_tokens=request.max_tokens,
+            top_p=request.top_p,
+        )
+
+        if request.stream:
+            return StreamingResponse(
+                _async_chat_completions(handle, dialog, sampling_params),
+                media_type="application/x-ndjson",
+            )
+        else:
+            response = await handle.chat(dialog=dialog, sampling_params=sampling_params)
+            return {
+                "id": f"chatcmpl-{uuid4().hex}",
+                "object": "chat.completion",
+                "model": request.model,
+                "created": int(time.time()),
+                "choices": [{"index": 0, "message": response["message"]}],
+            }
diff --git a/aana/core/models/chat.py b/aana/core/models/chat.py
index 56bbc4d1..494b9258 100644
--- a/aana/core/models/chat.py
+++ b/aana/core/models/chat.py
@@ -83,3 +83,73 @@ def from_list(cls, messages: list[dict[str, str]]) -> "ChatDialog":
             ChatDialog: the chat dialog
         """
         return ChatDialog(messages=[ChatMessage(**message) for message in messages])
+
+
+class ChatCompletionRequest(BaseModel):
+    """A chat completion request for OpenAI compatible API."""
+
+    model: str = Field(..., description="The model name (name of the LLM deployment).")
+    messages: list[ChatMessage] = Field(
+        ..., description="A list of messages comprising the conversation so far."
+    )
+    temperature: float | None = Field(
+        default=None,
+        ge=0.0,
+        description=(
+            "Float that controls the randomness of the sampling. "
+            "Lower values make the model more deterministic, "
+            "while higher values make the model more random. "
+            "Zero means greedy sampling."
+        ),
+    )
+    top_p: float | None = Field(
+        default=None,
+        gt=0.0,
+        le=1.0,
+        description=(
+            "Float that controls the cumulative probability of the top tokens to consider. "
+            "Must be in (0, 1]. Set to 1 to consider all tokens."
+        ),
+    )
+    max_tokens: int | None = Field(
+        default=None, ge=1, description="The maximum number of tokens to generate."
+    )
+
+    stream: bool | None = Field(
+        default=False,
+        description=(
+            "If set, partial message deltas will be sent, like in ChatGPT. "
+            "Tokens will be sent as data-only server-sent events as they become available, "
+            "with the stream terminated by a data: [DONE] message."
+        ),
+    )
+
+
+class ChatCompletionChoice(BaseModel):
+    """A chat completion choice for OpenAI compatible API."""
+
+    index: int = Field(
+        ..., description="The index of the choice in the list of choices."
+    )
+    message: ChatMessage = Field(
+        ..., description="A chat completion message generated by the model."
+    )
+
+
+class ChatCompletion(BaseModel):
+    """A chat completion for OpenAI compatible API."""
+
+    id: str = Field(..., description="A unique identifier for the chat completion.")
+    model: str = Field(..., description="The model used for the chat completion.")
+    created: int = Field(
+        ...,
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+    )
+    choices: list[ChatCompletionChoice] = Field(
+        ...,
+        description="A list of chat completion choices.",
+    )
+    object: Literal["chat.completion"] = Field(
+        "chat.completion",
+        description="The object type, which is always `chat.completion`.",
+    )
diff --git a/aana/tests/units/test_chat_completion.py b/aana/tests/units/test_chat_completion.py
new file mode 100644
index 00000000..1a146da5
--- /dev/null
+++ b/aana/tests/units/test_chat_completion.py
@@ -0,0 +1,110 @@
+# ruff: noqa: S101, S113
+from collections.abc import AsyncGenerator
+
+import pytest
+import requests
+from openai import NotFoundError, OpenAI
+from ray import serve
+
+from aana.core.models.chat import ChatDialog, ChatMessage
+from aana.core.models.sampling import SamplingParams
+from aana.deployments.base_text_generation_deployment import (
+    BaseTextGenerationDeployment,
+    ChatOutput,
+    LLMOutput,
+)
+
+
+@serve.deployment
+class LowercaseLLM(BaseTextGenerationDeployment):
+    """Ray deployment that returns the lowercase version of the input text, structured as an LLM."""
+
+    async def generate_stream(
+        self, prompt: str, sampling_params: SamplingParams | None = None
+    ) -> AsyncGenerator[LLMOutput, None]:
+        """Generate text stream.
+
+        Args:
+            prompt (str): The prompt.
+            sampling_params (SamplingParams): The sampling parameters.
+
+        Yields:
+            LLMOutput: The generated text.
+        """
+        for char in prompt:
+            yield LLMOutput(text=char.lower())
+
+    async def chat(
+        self, dialog: ChatDialog, sampling_params: SamplingParams | None = None
+    ) -> ChatOutput:
+        """Dummy chat method."""
+        text = dialog.messages[-1].content
+        return ChatOutput(message=ChatMessage(content=text.lower(), role="assistant"))
+
+    async def chat_stream(
+        self, dialog: ChatDialog, sampling_params: SamplingParams | None = None
+    ) -> AsyncGenerator[LLMOutput, None]:
+        """Dummy chat stream method."""
+        text = dialog.messages[-1].content
+        for char in text:
+            yield LLMOutput(text=char.lower())
+
+
+deployments = [
+    {
+        "name": "lowercase_deployment",
+        "instance": LowercaseLLM,
+    }
+]
+
+
+def test_chat_completion(app_setup):
+    """Test the chat completion endpoint for OpenAI compatible API."""
+    aana_app = app_setup(deployments, [])
+
+    port = aana_app.port
+    route_prefix = ""
+
+    # Check that the server is ready
+    response = requests.get(f"http://localhost:{port}{route_prefix}/api/ready")
+    assert response.status_code == 200
+    assert response.json() == {"ready": True}
+
+    messages = [
+        {"role": "user", "content": "Hello World!"},
+    ]
+    expected_output = messages[0]["content"].lower()
+
+    client = OpenAI(
+        api_key="token",
+        base_url=f"http://localhost:{port}",
+    )
+
+    # Test chat completion endpoint
+    completion = client.chat.completions.create(
+        messages=messages,
+        model="lowercase_deployment",
+    )
+    assert completion.choices[0].message.content == expected_output
+
+    # Test chat completion endpoint with stream
+    stream = client.chat.completions.create(
+        messages=messages,
+        model="lowercase_deployment",
+        stream=True,
+    )
+    generated_text = ""
+    for chunk in stream:
+        generated_text += chunk.choices[0].delta.content or ""
+    assert generated_text == expected_output
+
+    # Test chat completion endpoint with non-existent model
+    with pytest.raises(NotFoundError) as exc_info:
+        completion = client.chat.completions.create(
+            messages=messages,
+            model="non_existent_model",
+        )
+    assert (
+        exc_info.value.body["message"]
+        == "The model `non_existent_model` does not exist."
+    )
diff --git a/docs/integrations.md b/docs/integrations.md
index 1e59c736..1cf57f67 100644
--- a/docs/integrations.md
+++ b/docs/integrations.md
@@ -64,3 +64,7 @@ HfPipelineDeployment.options(
 Haystack integration allows you to build Retrieval-Augmented Generation (RAG) systems with the [Deepset Haystack](https://github.com/deepset-ai/haystack).
 
 TODO: Add example
+
+## OpenAI-compatible Chat Completions API
+
+The OpenAI-compatible Chat Completions API allows you to access Aana applications with any OpenAI-compatible client. See the [OpenAI-compatible API docs](/docs/openai_api.md) for more details.
diff --git a/docs/openai_api.md b/docs/openai_api.md
new file mode 100644
index 00000000..a9da2fd3
--- /dev/null
+++ b/docs/openai_api.md
@@ -0,0 +1,95 @@
+# OpenAI-compatible API
+
+Aana SDK provides an OpenAI-compatible Chat Completions API that allows you to integrate Aana with any OpenAI-compatible application.
+
+The Chat Completions API is available at the `/chat/completions` endpoint.
+
+It is compatible with the OpenAI client libraries and can be used as a drop-in replacement for the OpenAI API.
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="token",  # Any non-empty string will work, we don't require an API key
+    base_url="http://localhost:8000",
+)
+
+messages = [
+    {"role": "user", "content": "What is the capital of France?"}
+]
+
+completion = client.chat.completions.create(
+    messages=messages,
+    model="llm_deployment",
+)
+
+print(completion.choices[0].message.content)
+```
+
+The API also supports streaming:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="token",  # Any non-empty string will work, we don't require an API key
+    base_url="http://localhost:8000",
+)
+
+messages = [
+    {"role": "user", "content": "What is the capital of France?"}
+]
+
+stream = client.chat.completions.create(
+    messages=messages,
+    model="llm_deployment",
+    stream=True,
+)
+for chunk in stream:
+    print(chunk.choices[0].delta.content or "", end="")
+```
+
+The API requires an LLM deployment. Aana SDK provides support for [vLLM](/docs/integrations.md#vllm) and [Hugging Face Transformers](/docs/integrations.md#hugging-face-transformers).
+
+The name of the model matches the name of the deployment. For example, if you registered a vLLM deployment with the name `llm_deployment`, you can use it with the OpenAI API as `model="llm_deployment"`.
+
+```python
+import os
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+from aana.core.models.sampling import SamplingParams
+from aana.core.models.types import Dtype
+from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment
+from aana.sdk import AanaSDK
+
+llm_deployment = VLLMDeployment.options(
+    num_replicas=1,
+    ray_actor_options={"num_gpus": 1},
+    user_config=VLLMConfig(
+        model="TheBloke/Llama-2-7b-Chat-AWQ",
+        dtype=Dtype.AUTO,
+        quantization="awq",
+        gpu_memory_reserved=13000,
+        enforce_eager=True,
+        default_sampling_params=SamplingParams(
+            temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1024
+        ),
+        chat_template="llama2",
+    ).model_dump(mode="json"),
+)
+
+aana_app = AanaSDK(name="llm_app")
+aana_app.register_deployment(name="llm_deployment", instance=llm_deployment)
+
+if __name__ == "__main__":
+    aana_app.connect()
+    aana_app.migrate()
+    aana_app.deploy()
+```
+
+You can also use the example project `llama2` to deploy the Llama-2-7b Chat model.
+
+```bash
+CUDA_VISIBLE_DEVICES=0 aana deploy aana.projects.llama2.app:aana_app
+```
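+
+Since the endpoint is plain HTTP, you can also call it without the OpenAI client. The snippet below is a minimal sketch using the `requests` library; it assumes an Aana application with a deployment registered as `llm_deployment` is running on `localhost:8000` (adjust the host, port, and model name to your setup).
+
+```python
+# Minimal sketch: call the Chat Completions endpoint over plain HTTP.
+# Assumes an Aana app is running on localhost:8000 with a deployment
+# registered under the name "llm_deployment".
+import requests
+
+response = requests.post(
+    "http://localhost:8000/chat/completions",
+    json={
+        "model": "llm_deployment",
+        "messages": [{"role": "user", "content": "What is the capital of France?"}],
+        "temperature": 0.7,
+        "max_tokens": 256,
+    },
+)
+response.raise_for_status()
+completion = response.json()
+print(completion["choices"][0]["message"]["content"])
+```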