Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: migrate to new auth method #44

Merged
merged 10 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ AWS_ACCESS_KEY_ID=<key>
AWS_SECRET_ACCESS_KEY=<key>
DEFAULT_REGION=us-east-1

DIAL_USE_FILE_STORAGE=True
DIAL_URL=<dial core url>

# Misc env vars for the server
LOG_LEVEL=INFO # Default in prod is INFO. Use DEBUG for dev.
WEB_CONCURRENCY=1 # Number of uvicorn workers
Expand Down
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ Supported models:
- anthropic.claude-v2
* Stable Diffusion
- stability.stable-diffusion-xl
* Meta
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
* Cohere
- cohere.command-text-v14
- cohere.command-light-text-v14

## Developer environment

Expand Down Expand Up @@ -64,9 +70,8 @@ Copy `.env.example` to `.env` and customize it for your environment:
|DEFAULT_REGION||AWS region e.g. "us-east-1"|
|LOG_LEVEL|INFO|Log level. Use DEBUG for dev purposes and INFO in prod|
|AIDIAL_LOG_LEVEL|WARNING|AI DIAL SDK log level|
|DIAL_USE_FILE_STORAGE|False|Save model artifacts to DIAL File storage (particularly, Stability images are uploaded to the files storage and their base64 encodings are replaced with links to the storage)|
|DIAL_USE_FILE_STORAGE|False|Save model artifacts to DIAL File storage (particularly, Stability images are uploaded to the file storage and their base64 encodings are replaced with links to the storage). The credentials for the file storage must be passed in the `Authorization` header of the incoming request. The file storage won't be used if the header isn't set.|
|DIAL_URL||URL of the core DIAL server (required when DIAL_USE_FILE_STORAGE=True)|
|DIAL_API_KEY||API Key for DIAL File storage (required when DIAL_USE_FILE_STORAGE=True)|
|WEB_CONCURRENCY|1|Number of workers for the server|
|TEST_SERVER_URL|http://0.0.0.0:5001|Server URL used in the integration tests|

Expand Down
1 change: 1 addition & 0 deletions aidial_adapter_bedrock/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ async def chat_completion(self, request: Request, response: Response):
model = await get_bedrock_adapter(
region=self.region,
model=request.deployment_id,
headers=request.headers,
)

async def generate_response(
Expand Down
21 changes: 21 additions & 0 deletions aidial_adapter_bedrock/dial_api/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Mapping, Optional

from pydantic import BaseModel


class Auth(BaseModel):
    """A single authentication header kept as a name/value pair."""

    # Header name, e.g. the key the credential was read from.
    name: str
    # Header value, forwarded verbatim.
    value: str

    @property
    def headers(self) -> dict[str, str]:
        """Return the credential as a one-entry header dictionary."""
        header_map: dict[str, str] = {}
        header_map[self.name] = self.value
        return header_map

    @classmethod
    def from_headers(
        cls, name: str, headers: Mapping[str, str]
    ) -> Optional["Auth"]:
        """Build an ``Auth`` from the header called *name*.

        The lookup is a plain ``headers.get(name)``, so whether it is
        case-sensitive depends on the mapping that is passed in.

        Returns:
            An ``Auth`` holding the header, or ``None`` when the lookup
            yields ``None``.
        """
        found = headers.get(name)
        return None if found is None else cls(name=name, value=found)
105 changes: 80 additions & 25 deletions aidial_adapter_bedrock/dial_api/storage.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,52 @@
import base64
import hashlib
import io
from typing import TypedDict
from typing import Mapping, Optional, TypedDict

import aiohttp

from aidial_adapter_bedrock.dial_api.auth import Auth
from aidial_adapter_bedrock.utils.env import get_env, get_env_bool
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


class FileMetadata(TypedDict):
    """Typed view of the JSON metadata returned for a stored file.

    ``FileStorage.upload`` returns the parsed response body of the upload
    request annotated as this type.
    """

    # NOTE(review): this listing comes from a diff view, so it may mix fields
    # from before and after the change — confirm the final field set against
    # the storage API response.
    name: str
    parentPath: str
    bucket: str
    url: str
    type: str
    path: str

    contentLength: int
    contentType: str


class FileStorage:
base_url: str
api_key: str
dial_url: str
upload_dir: str
auth: Auth
bucket: Optional[str]

def __init__(self, dial_url: str, base_dir: str, api_key: str):
self.base_url = f"{dial_url}/v1/files/{base_dir}"
self.api_key = api_key
def __init__(self, dial_url: str, upload_dir: str, auth: Auth):
self.dial_url = dial_url
self.upload_dir = upload_dir
self.auth = auth
self.bucket = None

def auth_headers(self) -> dict[str, str]:
return {"api-key": self.api_key}
async def _get_bucket(self, session: aiohttp.ClientSession) -> str:
if self.bucket is None:
async with session.get(
f"{self.dial_url}/v1/bucket",
headers=self.auth.headers,
) as response:
response.raise_for_status()
body = await response.json()
self.bucket = body["bucket"]

return self.bucket

@staticmethod
def to_form_data(
def _to_form_data(
filename: str, content_type: str, content: bytes
) -> aiohttp.FormData:
data = aiohttp.FormData()
Expand All @@ -44,27 +62,64 @@ async def upload(
self, filename: str, content_type: str, content: bytes
) -> FileMetadata:
async with aiohttp.ClientSession() as session:
data = FileStorage.to_form_data(filename, content_type, content)
async with session.post(
self.base_url,
bucket = await self._get_bucket(session)
data = FileStorage._to_form_data(filename, content_type, content)
ext = _get_extension(content_type) or ""
url = f"{self.dial_url}/v1/files/{bucket}/{self.upload_dir}/{filename}{ext}"

async with session.put(
url=url,
data=data,
headers=self.auth_headers(),
headers=self.auth.headers,
) as response:
response.raise_for_status()
meta = await response.json()
log.debug(
f"Uploaded file: path={self.base_url}, file={filename}, metadata={meta}"
)
log.debug(f"Uploaded file: url={url}, metadata={meta}")
return meta

async def upload_file_as_base64(
    self, data: str, content_type: str
) -> FileMetadata:
    """Decode base64 *data* and upload it under a content-derived name.

    The file name is the hash digest of the base64 text itself (not of the
    decoded bytes), followed by ``.<ext>`` when an extension can be derived
    from *content_type*.

    Args:
        data: Base64-encoded file content.
        content_type: Content type passed on to ``upload``.

    Returns:
        Whatever ``upload`` returns for the stored file.
    """
    digest = _compute_hash_digest(data)
    suffix = _get_extension(content_type)

    # NOTE(review): upload() also derives the extension from content_type and
    # appends it to the file name when building the URL, so the extension may
    # end up applied twice — confirm which of the two places should own it.
    if suffix is None:
        target_name = digest
    else:
        target_name = f"{digest}.{suffix}"

    decoded: bytes = base64.b64decode(data)
    return await self.upload(target_name, content_type, decoded)


def _hash_digest(string: str) -> str:
return hashlib.sha256(string.encode()).hexdigest()
def _compute_hash_digest(file_content: str) -> str:
return hashlib.sha256(file_content.encode()).hexdigest()


async def upload_base64_file(
storage: FileStorage, data: str, content_type: str
) -> FileMetadata:
filename = _hash_digest(data)
content: bytes = base64.b64decode(data)
return await storage.upload(filename, content_type, content)
def _get_extension(content_type: str) -> Optional[str]:
if content_type.startswith("image/"):
return content_type[len("image/") :]
return None


# Feature flag: when true, model artifacts are saved to the DIAL file storage.
DIAL_USE_FILE_STORAGE = get_env_bool("DIAL_USE_FILE_STORAGE", False)

# DIAL_URL is read only when the file storage is switched on; get_env is
# given a custom error message to use if the variable is missing.
DIAL_URL: Optional[str] = (
    get_env("DIAL_URL", "DIAL_URL must be set to use the DIAL file storage")
    if DIAL_USE_FILE_STORAGE
    else None
)


def create_file_storage(
    base_dir: str, headers: Mapping[str, str]
) -> Optional[FileStorage]:
    """Create a ``FileStorage`` for this request, if one can be used.

    Args:
        base_dir: Directory that uploads are placed under (passed to
            ``FileStorage`` as ``upload_dir``).
        headers: Headers of the incoming request; the ``authorization``
            entry is used as the storage credential.

    Returns:
        A configured ``FileStorage``, or ``None`` when the file storage is
        disabled, ``DIAL_URL`` is unset, or the request carries no
        ``authorization`` header.
    """
    if DIAL_USE_FILE_STORAGE and DIAL_URL is not None:
        credentials = Auth.from_headers("authorization", headers)
        if credentials is not None:
            return FileStorage(
                dial_url=DIAL_URL, upload_dir=base_dir, auth=credentials
            )

        # Storage is enabled but the caller supplied no credentials.
        log.debug(
            "The request doesn't have required headers to use the DIAL file storage. "
            "Fallback to base64 encoding of images."
        )

    return None
8 changes: 6 additions & 2 deletions aidial_adapter_bedrock/llm/model/adapter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Mapping

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.llm.chat_emulator import default_emulator
from aidial_adapter_bedrock.llm.chat_model import ChatModel, Model
Expand All @@ -10,7 +12,9 @@
from aidial_adapter_bedrock.llm.tokenize import default_tokenize


async def get_bedrock_adapter(model: str, region: str) -> ChatModel:
async def get_bedrock_adapter(
model: str, region: str, headers: Mapping[str, str]
) -> ChatModel:
client = await Bedrock.acreate(region)
provider = Model.parse(model).provider
match provider:
Expand All @@ -21,7 +25,7 @@ async def get_bedrock_adapter(model: str, region: str) -> ChatModel:
client, model, default_tokenize, default_emulator
)
case "stability":
return StabilityAdapter(client, model)
return await StabilityAdapter.create(client, model, headers)
case "amazon":
return AmazonAdapter(
client, model, default_tokenize, default_emulator
Expand Down
41 changes: 17 additions & 24 deletions aidial_adapter_bedrock/llm/model/stability.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import os
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, Field

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.storage import (
FileStorage,
upload_base64_file,
create_file_storage,
)
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import ChatModel, ChatPrompt
from aidial_adapter_bedrock.llm.consumer import Attachment, Consumer
from aidial_adapter_bedrock.llm.exceptions import ValidationError
from aidial_adapter_bedrock.llm.message import BaseMessage
from aidial_adapter_bedrock.utils.env import get_env


class StabilityStatus(str, Enum):
Expand Down Expand Up @@ -79,42 +77,37 @@ async def save_to_storage(
and attachment.type.startswith("image/")
and attachment.data is not None
):
response = await upload_base64_file(
storage, attachment.data, attachment.type
response = await storage.upload_file_as_base64(
attachment.data, attachment.type
)
return Attachment(
title=attachment.title,
type=attachment.type,
url=response["path"] + "/" + response["name"],
url=response["url"],
)

return attachment


DIAL_USE_FILE_STORAGE = (
os.getenv("DIAL_USE_FILE_STORAGE", "false").lower() == "true"
)

if DIAL_USE_FILE_STORAGE:
DIAL_URL = get_env("DIAL_URL")
DIAL_API_KEY = get_env("DIAL_API_KEY")


class StabilityAdapter(ChatModel):
client: Bedrock
storage: Optional[FileStorage]

def __init__(self, client: Bedrock, model: str):
def __init__(
self, client: Bedrock, model: str, storage: Optional[FileStorage]
):
super().__init__(model)
self.client = client
self.storage = None
self.storage = storage

if DIAL_USE_FILE_STORAGE:
self.storage = FileStorage(
dial_url=DIAL_URL,
api_key=DIAL_API_KEY,
base_dir="stability",
)
@classmethod
async def create(
cls, client: Bedrock, model: str, headers: Mapping[str, str]
):
storage: Optional[FileStorage] = create_file_storage(
"images/stable-diffusion", headers
)
return cls(client, model, storage)

def _prepare_prompt(
self, messages: List[BaseMessage], max_prompt_tokens: Optional[int]
Expand Down
9 changes: 7 additions & 2 deletions aidial_adapter_bedrock/utils/env.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import os
from typing import Optional


def get_env(name: str) -> str:
def get_env(name: str, err_msg: Optional[str] = None) -> str:
if name in os.environ:
val = os.environ.get(name)
if val is not None:
return val

raise Exception(f"{name} env variable is not set")
raise Exception(err_msg or f"{name} env variable is not set")


def get_env_bool(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the environment.

    The value is considered true only when it equals ``"true"``, compared
    case-insensitively and ignoring surrounding whitespace, so values such
    as ``"True "`` coming from hand-edited env files are not silently
    treated as false.

    Args:
        name: Name of the environment variable.
        default: Value used when the variable is not set.

    Returns:
        ``True`` if the variable (or the default) spells "true", else
        ``False``.
    """
    raw_value = os.getenv(name, str(default))
    return raw_value.strip().lower() == "true"
1 change: 1 addition & 0 deletions client/client_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ async def main():
model = await get_bedrock_adapter(
model=deployment.get_model_id(),
region=location,
headers={},
)

messages: List[Message] = []
Expand Down
Loading