Skip to content

Commit

Permalink
feat: save stability artifacts to DIAL file storage (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Dec 1, 2023
1 parent e27a1d1 commit 590b054
Show file tree
Hide file tree
Showing 16 changed files with 155 additions and 53 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ COPY ./scripts/docker_entrypoint.sh /docker_entrypoint.sh
RUN chmod +x /docker_entrypoint.sh

ENV LOG_LEVEL=INFO
ENV USE_DIAL_FILE_STORAGE=True
EXPOSE 5000

USER appuser
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ Copy `.env.example` to `.env` and customize it for your environment:
|DEFAULT_REGION||AWS region e.g. "us-east-1"|
|LOG_LEVEL|INFO|Log level. Use DEBUG for dev purposes and INFO in prod|
|AIDIAL_LOG_LEVEL|WARNING|AI DIAL SDK log level|
|USE_DIAL_FILE_STORAGE|False|Save model artifacts to DIAL File storage (particularly, Stability images are uploaded to the files storage and their base64 encodings are replaced with links to the storage)|
|DIAL_URL||URL of the core DIAL server|
|DIAL_BEDROCK_API_KEY||API Key for DIAL File storage|
|WEB_CONCURRENCY|1|Number of workers for the server|
|TEST_SERVER_URL|http://0.0.0.0:5001|Server URL used in the integration tests|

Expand Down
7 changes: 2 additions & 5 deletions aidial_adapter_bedrock/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@
from fastapi.responses import JSONResponse

from aidial_adapter_bedrock.chat_completion import BedrockChatCompletion
from aidial_adapter_bedrock.dial_api.response import ModelObject, ModelsResponse
from aidial_adapter_bedrock.llm.bedrock_models import BedrockDeployment
from aidial_adapter_bedrock.llm.model_listing import get_bedrock_models
from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator
from aidial_adapter_bedrock.universal_api.response import (
ModelObject,
ModelsResponse,
)
from aidial_adapter_bedrock.utils.env import get_env
from aidial_adapter_bedrock.utils.log_config import LogConfig
from aidial_adapter_bedrock.utils.log_config import app_logger as log
Expand All @@ -22,7 +19,7 @@

default_region = get_env("DEFAULT_REGION")

app = DIALApp(description="AWS Bedrock adapter for RAIL API")
app = DIALApp(description="AWS Bedrock adapter for DIAL API")


@app.get("/healthcheck")
Expand Down
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from aidial_sdk.chat_completion import ChatCompletion, Request, Response

from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.consumer import ChoiceConsumer
from aidial_adapter_bedrock.llm.model.adapter import get_bedrock_adapter
from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator
from aidial_adapter_bedrock.universal_api.request import ModelParameters
from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage
from aidial_adapter_bedrock.utils.log_config import app_logger as log


Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
70 changes: 70 additions & 0 deletions aidial_adapter_bedrock/dial_api/storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import base64
import hashlib
import io
from typing import TypedDict

import aiohttp

from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


class FileMetadata(TypedDict):
name: str
type: str
path: str
contentLength: int
contentType: str


class FileStorage:
base_url: str
api_key: str

def __init__(self, dial_url: str, base_dir: str, api_key: str):
self.base_url = f"{dial_url}/v1/files/{base_dir}"
self.api_key = api_key

def auth_headers(self) -> dict[str, str]:
return {"api-key": self.api_key}

@staticmethod
def to_form_data(
filename: str, content_type: str, content: bytes
) -> aiohttp.FormData:
data = aiohttp.FormData()
data.add_field(
"file",
io.BytesIO(content),
filename=filename,
content_type=content_type,
)
return data

async def upload(
self, filename: str, content_type: str, content: bytes
) -> FileMetadata:
async with aiohttp.ClientSession() as session:
data = FileStorage.to_form_data(filename, content_type, content)
async with session.post(
self.base_url,
data=data,
headers=self.auth_headers(),
) as response:
response.raise_for_status()
meta = await response.json()
log.debug(
f"Uploaded file: path={self.base_url}, file={filename}, metadata={meta}"
)
return meta


def _hash_digest(string: str) -> str:
return hashlib.sha256(string.encode()).hexdigest()


async def upload_base64_file(
storage: FileStorage, data: str, content_type: str
) -> FileMetadata:
filename = _hash_digest(data)
content: bytes = base64.b64decode(data)
return await storage.upload(filename, content_type, content)
2 changes: 1 addition & 1 deletion aidial_adapter_bedrock/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from aidial_sdk.chat_completion import Message
from pydantic import BaseModel

from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import (
PseudoChatHistory,
)
Expand All @@ -14,7 +15,6 @@
SystemMessage,
parse_message,
)
from aidial_adapter_bedrock.universal_api.request import ModelParameters
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


Expand Down
2 changes: 1 addition & 1 deletion aidial_adapter_bedrock/llm/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from aidial_sdk.chat_completion import Choice
from pydantic import BaseModel

from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage


class Attachment(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/llm/model/ai21.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from pydantic import BaseModel

from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import PseudoChatModel
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_AI21
from aidial_adapter_bedrock.universal_api.request import ModelParameters
from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage
from aidial_adapter_bedrock.utils.concurrency import make_async


Expand Down
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/llm/model/amazon.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from typing_extensions import override

import aidial_adapter_bedrock.utils.stream as stream
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import RolePrompt
from aidial_adapter_bedrock.llm.chat_model import PseudoChatModel
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.message import BaseMessage
from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_AMAZON
from aidial_adapter_bedrock.universal_api.request import ModelParameters
from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage
from aidial_adapter_bedrock.utils.concurrency import make_async
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log

Expand Down
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/llm/model/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from anthropic.tokenizer import count_tokens

from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_emulation import claude_chat
from aidial_adapter_bedrock.llm.chat_emulation.claude_chat import (
ClaudeChatHistory,
Expand All @@ -11,8 +13,6 @@
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.message import BaseMessage
from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_ANTHROPIC
from aidial_adapter_bedrock.universal_api.request import ModelParameters
from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage
from aidial_adapter_bedrock.utils.concurrency import make_async
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log

Expand Down
105 changes: 68 additions & 37 deletions aidial_adapter_bedrock/llm/model/stability.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import json
import os
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field

from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.storage import (
FileStorage,
upload_base64_file,
)
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_emulation.zero_memory_chat import (
ZeroMemoryChatHistory,
)
from aidial_adapter_bedrock.llm.chat_model import ChatModel, ChatPrompt
from aidial_adapter_bedrock.llm.consumer import Attachment, Consumer
from aidial_adapter_bedrock.llm.message import BaseMessage
from aidial_adapter_bedrock.universal_api.request import ModelParameters
from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage
from aidial_adapter_bedrock.utils.concurrency import make_async


class ResponseData(BaseModel):
mime_type: str
name: str
content: str
from aidial_adapter_bedrock.utils.env import get_env


class StabilityStatus(str, Enum):
Expand All @@ -40,25 +40,26 @@ class StabilityArtifact(BaseModel):

class StabilityResponse(BaseModel):
# TODO: Use tagged union artifacts/error
result: str
result: StabilityStatus
artifacts: Optional[list[StabilityArtifact]]
error: Optional[StabilityError]

def content(self) -> str:
self._throw_if_error()
return ""

def data(self) -> list[ResponseData]:
def attachments(self) -> list[Attachment]:
self._throw_if_error()
return [
ResponseData(
mime_type="image/png",
name="image",
content=self.artifacts[0].base64, # type: ignore
Attachment(
title="image",
type="image/png",
data=self.artifacts[0].base64, # type: ignore
)
]

def usage(self) -> TokenUsage:
self._throw_if_error()
return TokenUsage(
prompt_tokens=0,
completion_tokens=1,
Expand All @@ -73,14 +74,50 @@ def prepare_input(prompt: str) -> Dict[str, Any]:
return {"text_prompts": [{"text": prompt}]}


class StabilityAdapter(ChatModel):
def __init__(
self,
bedrock: Any,
model_id: str,
async def save_to_storage(
storage: FileStorage, attachment: Attachment
) -> Attachment:
if (
attachment.type is not None
and attachment.type.startswith("image/")
and attachment.data is not None
):
response = await upload_base64_file(
storage, attachment.data, attachment.type
)
return Attachment(
title=attachment.title,
type=attachment.type,
url=response["path"] + "/" + response["name"],
)

return attachment


USE_DIAL_FILE_STORAGE = (
os.getenv("USE_DIAL_FILE_STORAGE", "false").lower() == "true"
)

if USE_DIAL_FILE_STORAGE:
DIAL_URL = get_env("DIAL_URL")
DIAL_BEDROCK_API_KEY = get_env("DIAL_BEDROCK_API_KEY")


class StabilityAdapter(ChatModel):
bedrock: Any
storage: Optional[FileStorage]

def __init__(self, bedrock: Any, model_id: str):
super().__init__(model_id)
self.bedrock = bedrock
self.storage = None

if USE_DIAL_FILE_STORAGE:
self.storage = FileStorage(
dial_url=DIAL_URL,
api_key=DIAL_BEDROCK_API_KEY,
base_dir="stability",
)

def _prepare_prompt(
self, messages: List[BaseMessage], max_prompt_tokens: Optional[int]
Expand All @@ -95,16 +132,14 @@ def _prepare_prompt(
async def _apredict(
self, consumer: Consumer, model_params: ModelParameters, prompt: str
):
return await make_async(
lambda args: self._call(*args), (consumer, prompt)
)

def _call(self, consumer: Consumer, prompt: str):
model_response = self.bedrock.invoke_model(
modelId=self.model_id,
accept="application/json",
contentType="application/json",
body=json.dumps(prepare_input(prompt)),
model_response = await make_async(
lambda args: self.bedrock.invoke_model(
accept="application/json",
contentType="application/json",
modelId=args[0],
body=args[1],
),
(self.model_id, json.dumps(prepare_input(prompt))),
)

body = json.loads(model_response["body"].read())
Expand All @@ -113,11 +148,7 @@ def _call(self, consumer: Consumer, prompt: str):
consumer.append_content(resp.content())
consumer.add_usage(resp.usage())

for data in resp.data():
consumer.add_attachment(
Attachment(
title=data.name,
data=data.content,
type=data.mime_type,
)
)
for attachment in resp.attachments():
if self.storage is not None:
attachment = await save_to_storage(self.storage, attachment)
consumer.add_attachment(attachment)
2 changes: 1 addition & 1 deletion client/client_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from aidial_sdk.chat_completion import Message, Role

from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.llm.bedrock_models import BedrockDeployment
from aidial_adapter_bedrock.llm.consumer import CollectConsumer
from aidial_adapter_bedrock.llm.model.adapter import get_bedrock_adapter
from aidial_adapter_bedrock.universal_api.request import ModelParameters
from aidial_adapter_bedrock.utils.env import get_env
from aidial_adapter_bedrock.utils.printing import print_ai, print_info
from client.conf import MAX_CHAT_TURN, MAX_INPUT_CHARS
Expand Down

0 comments on commit 590b054

Please sign in to comment.