diff --git a/Dockerfile b/Dockerfile index 9b7502d..0365070 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,6 +27,7 @@ COPY ./scripts/docker_entrypoint.sh /docker_entrypoint.sh RUN chmod +x /docker_entrypoint.sh ENV LOG_LEVEL=INFO +ENV USE_DIAL_FILE_STORAGE=True EXPOSE 5000 USER appuser diff --git a/README.md b/README.md index 6556680..57dcbd0 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,9 @@ Copy `.env.example` to `.env` and customize it for your environment: |DEFAULT_REGION||AWS region e.g. "us-east-1"| |LOG_LEVEL|INFO|Log level. Use DEBUG for dev purposes and INFO in prod| |AIDIAL_LOG_LEVEL|WARNING|AI DIAL SDK log level| +|USE_DIAL_FILE_STORAGE|False|Save model artifacts to the DIAL file storage: in particular, Stability images are uploaded to the file storage and their base64 encodings are replaced with links to the uploaded files. When enabled, DIAL_URL and DIAL_BEDROCK_API_KEY are read to access the storage. Note that the Docker image sets this variable to True| +|DIAL_URL||URL of the core DIAL server| +|DIAL_BEDROCK_API_KEY||API key for the DIAL file storage| |WEB_CONCURRENCY|1|Number of workers for the server| |TEST_SERVER_URL|http://0.0.0.0:5001|Server URL used in the integration tests| diff --git a/aidial_adapter_bedrock/app.py b/aidial_adapter_bedrock/app.py index aea0b7d..6c5627e 100644 --- a/aidial_adapter_bedrock/app.py +++ b/aidial_adapter_bedrock/app.py @@ -7,13 +7,10 @@ from fastapi.responses import JSONResponse from aidial_adapter_bedrock.chat_completion import BedrockChatCompletion +from aidial_adapter_bedrock.dial_api.response import ModelObject, ModelsResponse from aidial_adapter_bedrock.llm.bedrock_models import BedrockDeployment from aidial_adapter_bedrock.llm.model_listing import get_bedrock_models from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator -from aidial_adapter_bedrock.universal_api.response import ( - ModelObject, - ModelsResponse, -) from aidial_adapter_bedrock.utils.env import get_env from aidial_adapter_bedrock.utils.log_config import LogConfig from aidial_adapter_bedrock.utils.log_config import app_logger as log @@ -22,7 +19,7 @@ 
default_region = get_env("DEFAULT_REGION") -app = DIALApp(description="AWS Bedrock adapter for RAIL API") +app = DIALApp(description="AWS Bedrock adapter for DIAL API") @app.get("/healthcheck") diff --git a/aidial_adapter_bedrock/chat_completion.py b/aidial_adapter_bedrock/chat_completion.py index 182a05d..acd9b7b 100644 --- a/aidial_adapter_bedrock/chat_completion.py +++ b/aidial_adapter_bedrock/chat_completion.py @@ -3,11 +3,11 @@ from aidial_sdk.chat_completion import ChatCompletion, Request, Response +from aidial_adapter_bedrock.dial_api.request import ModelParameters +from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage from aidial_adapter_bedrock.llm.consumer import ChoiceConsumer from aidial_adapter_bedrock.llm.model.adapter import get_bedrock_adapter from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator -from aidial_adapter_bedrock.universal_api.request import ModelParameters -from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage from aidial_adapter_bedrock.utils.log_config import app_logger as log diff --git a/aidial_adapter_bedrock/universal_api/__init__.py b/aidial_adapter_bedrock/dial_api/__init__.py similarity index 100% rename from aidial_adapter_bedrock/universal_api/__init__.py rename to aidial_adapter_bedrock/dial_api/__init__.py diff --git a/aidial_adapter_bedrock/universal_api/request.py b/aidial_adapter_bedrock/dial_api/request.py similarity index 100% rename from aidial_adapter_bedrock/universal_api/request.py rename to aidial_adapter_bedrock/dial_api/request.py diff --git a/aidial_adapter_bedrock/universal_api/response.py b/aidial_adapter_bedrock/dial_api/response.py similarity index 100% rename from aidial_adapter_bedrock/universal_api/response.py rename to aidial_adapter_bedrock/dial_api/response.py diff --git a/aidial_adapter_bedrock/dial_api/storage.py b/aidial_adapter_bedrock/dial_api/storage.py new file mode 100644 index 0000000..6b03597 --- /dev/null +++ 
# b/aidial_adapter_bedrock/dial_api/storage.py -- new file added by this patch
# (hunk @@ -0,0 +1,70 @@).
"""Minimal client for uploading files to the DIAL file storage."""

import base64
import hashlib
import io
from typing import TypedDict

import aiohttp

from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


class FileMetadata(TypedDict):
    """JSON object the file storage answers with after an upload."""

    name: str
    type: str
    path: str
    contentLength: int
    contentType: str


class FileStorage:
    """Uploads files to ``{dial_url}/v1/files/{base_dir}`` with an API key."""

    base_url: str
    api_key: str

    def __init__(self, dial_url: str, base_dir: str, api_key: str):
        """Remember the upload endpoint and the API key.

        Args:
            dial_url: Base URL of the DIAL server.
            base_dir: Directory under ``/v1/files`` that receives the uploads.
            api_key: Value sent in the ``api-key`` header of every request.
        """
        # Trim the slashes around the joints so that a trailing slash in
        # ``dial_url`` (or stray slashes in ``base_dir``) cannot produce a
        # malformed ``//`` in the upload URL.
        root = dial_url.rstrip("/")
        folder = base_dir.strip("/")
        self.base_url = f"{root}/v1/files/{folder}"
        self.api_key = api_key

    def auth_headers(self) -> dict[str, str]:
        """Return the HTTP headers that authenticate a request."""
        return {"api-key": self.api_key}

    @staticmethod
    def to_form_data(
        filename: str, content_type: str, content: bytes
    ) -> aiohttp.FormData:
        """Wrap *content* into a form payload with a single ``file`` field."""
        data = aiohttp.FormData()
        data.add_field(
            "file",
            io.BytesIO(content),
            filename=filename,
            content_type=content_type,
        )
        return data

    async def upload(
        self, filename: str, content_type: str, content: bytes
    ) -> FileMetadata:
        """POST the file to ``base_url`` and return the storage's JSON answer.

        A fresh ``aiohttp.ClientSession`` is opened for every call.

        Raises:
            aiohttp.ClientResponseError: The storage answered with an HTTP
                error status (raised by ``response.raise_for_status()``).
        """
        async with aiohttp.ClientSession() as session:
            data = FileStorage.to_form_data(filename, content_type, content)
            async with session.post(
                self.base_url,
                data=data,
                headers=self.auth_headers(),
            ) as response:
                response.raise_for_status()
                meta = await response.json()
                log.debug(
                    f"Uploaded file: path={self.base_url}, file={filename}, metadata={meta}"
                )
                return meta


def _hash_digest(string: str) -> str:
    """Return the hex SHA-256 digest of the UTF-8 encoded *string*."""
    return hashlib.sha256(string.encode()).hexdigest()


async def upload_base64_file(
    storage: FileStorage, data: str, content_type: str
) -> FileMetadata:
    """Decode base64 *data* and upload the resulting bytes to *storage*.

    The file name is the SHA-256 digest of the base64 text itself, so the
    same payload always maps to the same file name.
    """
    filename = _hash_digest(data)
    content: bytes = base64.b64decode(data)
    return await storage.upload(filename, content_type, content)


# Patch metadata that followed the file on this line:
# aidial_adapter_bedrock/universal_api/token_usage.py is renamed to
# aidial_adapter_bedrock/dial_api/token_usage.py (similarity index 100%).
a/aidial_adapter_bedrock/llm/chat_model.py b/aidial_adapter_bedrock/llm/chat_model.py index b5fdb35..88ec9b1 100644 --- a/aidial_adapter_bedrock/llm/chat_model.py +++ b/aidial_adapter_bedrock/llm/chat_model.py @@ -4,6 +4,7 @@ from aidial_sdk.chat_completion import Message from pydantic import BaseModel +from aidial_adapter_bedrock.dial_api.request import ModelParameters from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import ( PseudoChatHistory, ) @@ -14,7 +15,6 @@ SystemMessage, parse_message, ) -from aidial_adapter_bedrock.universal_api.request import ModelParameters from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log diff --git a/aidial_adapter_bedrock/llm/consumer.py b/aidial_adapter_bedrock/llm/consumer.py index 2d5f0f9..09a5486 100644 --- a/aidial_adapter_bedrock/llm/consumer.py +++ b/aidial_adapter_bedrock/llm/consumer.py @@ -4,7 +4,7 @@ from aidial_sdk.chat_completion import Choice from pydantic import BaseModel -from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage +from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage class Attachment(BaseModel): diff --git a/aidial_adapter_bedrock/llm/model/ai21.py b/aidial_adapter_bedrock/llm/model/ai21.py index 1b536af..e9f6040 100644 --- a/aidial_adapter_bedrock/llm/model/ai21.py +++ b/aidial_adapter_bedrock/llm/model/ai21.py @@ -3,11 +3,11 @@ from pydantic import BaseModel +from aidial_adapter_bedrock.dial_api.request import ModelParameters +from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage from aidial_adapter_bedrock.llm.chat_model import PseudoChatModel from aidial_adapter_bedrock.llm.consumer import Consumer from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_AI21 -from aidial_adapter_bedrock.universal_api.request import ModelParameters -from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage from aidial_adapter_bedrock.utils.concurrency import make_async diff --git 
a/aidial_adapter_bedrock/llm/model/amazon.py b/aidial_adapter_bedrock/llm/model/amazon.py index c8ffc8b..993c773 100644 --- a/aidial_adapter_bedrock/llm/model/amazon.py +++ b/aidial_adapter_bedrock/llm/model/amazon.py @@ -5,13 +5,13 @@ from typing_extensions import override import aidial_adapter_bedrock.utils.stream as stream +from aidial_adapter_bedrock.dial_api.request import ModelParameters +from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import RolePrompt from aidial_adapter_bedrock.llm.chat_model import PseudoChatModel from aidial_adapter_bedrock.llm.consumer import Consumer from aidial_adapter_bedrock.llm.message import BaseMessage from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_AMAZON -from aidial_adapter_bedrock.universal_api.request import ModelParameters -from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage from aidial_adapter_bedrock.utils.concurrency import make_async from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log diff --git a/aidial_adapter_bedrock/llm/model/anthropic.py b/aidial_adapter_bedrock/llm/model/anthropic.py index 1327489..db87471 100644 --- a/aidial_adapter_bedrock/llm/model/anthropic.py +++ b/aidial_adapter_bedrock/llm/model/anthropic.py @@ -3,6 +3,8 @@ from anthropic.tokenizer import count_tokens +from aidial_adapter_bedrock.dial_api.request import ModelParameters +from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage from aidial_adapter_bedrock.llm.chat_emulation import claude_chat from aidial_adapter_bedrock.llm.chat_emulation.claude_chat import ( ClaudeChatHistory, @@ -11,8 +13,6 @@ from aidial_adapter_bedrock.llm.consumer import Consumer from aidial_adapter_bedrock.llm.message import BaseMessage from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_ANTHROPIC -from aidial_adapter_bedrock.universal_api.request import ModelParameters -from 
aidial_adapter_bedrock.universal_api.token_usage import TokenUsage from aidial_adapter_bedrock.utils.concurrency import make_async from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log diff --git a/aidial_adapter_bedrock/llm/model/stability.py b/aidial_adapter_bedrock/llm/model/stability.py index 242af57..1183d1c 100644 --- a/aidial_adapter_bedrock/llm/model/stability.py +++ b/aidial_adapter_bedrock/llm/model/stability.py @@ -1,24 +1,24 @@ import json +import os from enum import Enum from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field +from aidial_adapter_bedrock.dial_api.request import ModelParameters +from aidial_adapter_bedrock.dial_api.storage import ( + FileStorage, + upload_base64_file, +) +from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage from aidial_adapter_bedrock.llm.chat_emulation.zero_memory_chat import ( ZeroMemoryChatHistory, ) from aidial_adapter_bedrock.llm.chat_model import ChatModel, ChatPrompt from aidial_adapter_bedrock.llm.consumer import Attachment, Consumer from aidial_adapter_bedrock.llm.message import BaseMessage -from aidial_adapter_bedrock.universal_api.request import ModelParameters -from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage from aidial_adapter_bedrock.utils.concurrency import make_async - - -class ResponseData(BaseModel): - mime_type: str - name: str - content: str +from aidial_adapter_bedrock.utils.env import get_env class StabilityStatus(str, Enum): @@ -40,7 +40,7 @@ class StabilityArtifact(BaseModel): class StabilityResponse(BaseModel): # TODO: Use tagged union artifacts/error - result: str + result: StabilityStatus artifacts: Optional[list[StabilityArtifact]] error: Optional[StabilityError] @@ -48,17 +48,18 @@ def content(self) -> str: self._throw_if_error() return "" - def data(self) -> list[ResponseData]: + def attachments(self) -> list[Attachment]: self._throw_if_error() return [ - ResponseData( - mime_type="image/png", - 
name="image", - content=self.artifacts[0].base64, # type: ignore + Attachment( + title="image", + type="image/png", + data=self.artifacts[0].base64, # type: ignore ) ] def usage(self) -> TokenUsage: + self._throw_if_error() return TokenUsage( prompt_tokens=0, completion_tokens=1, @@ -73,14 +74,50 @@ def prepare_input(prompt: str) -> Dict[str, Any]: return {"text_prompts": [{"text": prompt}]} -class StabilityAdapter(ChatModel): - def __init__( - self, - bedrock: Any, - model_id: str, +async def save_to_storage( + storage: FileStorage, attachment: Attachment +) -> Attachment: + if ( + attachment.type is not None + and attachment.type.startswith("image/") + and attachment.data is not None ): + response = await upload_base64_file( + storage, attachment.data, attachment.type + ) + return Attachment( + title=attachment.title, + type=attachment.type, + url=response["path"] + "/" + response["name"], + ) + + return attachment + + +USE_DIAL_FILE_STORAGE = ( + os.getenv("USE_DIAL_FILE_STORAGE", "false").lower() == "true" +) + +if USE_DIAL_FILE_STORAGE: + DIAL_URL = get_env("DIAL_URL") + DIAL_BEDROCK_API_KEY = get_env("DIAL_BEDROCK_API_KEY") + + +class StabilityAdapter(ChatModel): + bedrock: Any + storage: Optional[FileStorage] + + def __init__(self, bedrock: Any, model_id: str): super().__init__(model_id) self.bedrock = bedrock + self.storage = None + + if USE_DIAL_FILE_STORAGE: + self.storage = FileStorage( + dial_url=DIAL_URL, + api_key=DIAL_BEDROCK_API_KEY, + base_dir="stability", + ) def _prepare_prompt( self, messages: List[BaseMessage], max_prompt_tokens: Optional[int] @@ -95,16 +132,14 @@ def _prepare_prompt( async def _apredict( self, consumer: Consumer, model_params: ModelParameters, prompt: str ): - return await make_async( - lambda args: self._call(*args), (consumer, prompt) - ) - - def _call(self, consumer: Consumer, prompt: str): - model_response = self.bedrock.invoke_model( - modelId=self.model_id, - accept="application/json", - contentType="application/json", - 
body=json.dumps(prepare_input(prompt)), + model_response = await make_async( + lambda args: self.bedrock.invoke_model( + accept="application/json", + contentType="application/json", + modelId=args[0], + body=args[1], + ), + (self.model_id, json.dumps(prepare_input(prompt))), ) body = json.loads(model_response["body"].read()) @@ -113,11 +148,7 @@ def _call(self, consumer: Consumer, prompt: str): consumer.append_content(resp.content()) consumer.add_usage(resp.usage()) - for data in resp.data(): - consumer.add_attachment( - Attachment( - title=data.name, - data=data.content, - type=data.mime_type, - ) - ) + for attachment in resp.attachments(): + if self.storage is not None: + attachment = await save_to_storage(self.storage, attachment) + consumer.add_attachment(attachment) diff --git a/client/client_bedrock.py b/client/client_bedrock.py index eba326e..8849389 100755 --- a/client/client_bedrock.py +++ b/client/client_bedrock.py @@ -3,10 +3,10 @@ from aidial_sdk.chat_completion import Message, Role +from aidial_adapter_bedrock.dial_api.request import ModelParameters from aidial_adapter_bedrock.llm.bedrock_models import BedrockDeployment from aidial_adapter_bedrock.llm.consumer import CollectConsumer from aidial_adapter_bedrock.llm.model.adapter import get_bedrock_adapter -from aidial_adapter_bedrock.universal_api.request import ModelParameters from aidial_adapter_bedrock.utils.env import get_env from aidial_adapter_bedrock.utils.printing import print_ai, print_info from client.conf import MAX_CHAT_TURN, MAX_INPUT_CHARS