Skip to content

Commit

Permalink
feat: migrated to the latest DIAL SDK (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Jul 26, 2024
1 parent d07201e commit fc1f4ae
Show file tree
Hide file tree
Showing 31 changed files with 1,032 additions and 444 deletions.
5 changes: 4 additions & 1 deletion .ort.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,7 @@ resolutions:
comment: 'BSD 3-Clause "New" or "Revised" License: https://github.com/encode/httpcore/blob/0.18.0/LICENSE.md'
- message: ".*PyPI::werkzeug:3\\.0\\.1.*"
reason: 'CANT_FIX_EXCEPTION'
comment: 'BSD 3-Clause "New" or "Revised" License: https://github.com/pallets/werkzeug/blob/3.0.1/LICENSE.rst'
comment: 'BSD 3-Clause "New" or "Revised" License: https://github.com/pallets/werkzeug/blob/3.0.1/LICENSE.rst'
- message: ".*PyPI::numpy:2\\.0\\.0.*"
reason: 'CANT_FIX_EXCEPTION'
comment: 'BSD License: https://github.com/numpy/numpy/blob/v2.0.0/LICENSE.txt'
44 changes: 6 additions & 38 deletions aidial_adapter_bedrock/app.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,15 @@
import json
from typing import Optional

from aidial_sdk import DIALApp
from aidial_sdk.telemetry.types import TelemetryConfig
from fastapi import Body, Header, Path

from aidial_adapter_bedrock.chat_completion import BedrockChatCompletion
from aidial_adapter_bedrock.deployments import (
ChatCompletionDeployment,
EmbeddingsDeployment,
)
from aidial_adapter_bedrock.dial_api.request import (
EmbeddingsRequest,
EmbeddingsType,
)
from aidial_adapter_bedrock.dial_api.response import (
ModelObject,
ModelsResponse,
make_embeddings_response,
)
from aidial_adapter_bedrock.llm.model.adapter import get_embeddings_model
from aidial_adapter_bedrock.dial_api.response import ModelObject, ModelsResponse
from aidial_adapter_bedrock.embeddings import BedrockEmbeddings
from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator
from aidial_adapter_bedrock.utils.env import get_aws_default_region
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log
from aidial_adapter_bedrock.utils.log_config import configure_loggers

AWS_DEFAULT_REGION = get_aws_default_region()
Expand Down Expand Up @@ -56,27 +43,8 @@ async def models():
BedrockChatCompletion(region=AWS_DEFAULT_REGION),
)


@app.post("/openai/deployments/{deployment}/embeddings")
@dial_exception_decorator
async def embeddings(
embeddings_type: EmbeddingsType = Header(
alias="X-DIAL-Type", default=EmbeddingsType.SYMMETRIC
),
embeddings_instruction: Optional[str] = Header(
alias="X-DIAL-Instruction", default=None
),
deployment: EmbeddingsDeployment = Path(...),
request: dict = Body(..., examples=[EmbeddingsRequest.example()]),
):
log.debug(f"request: {json.dumps(request)}")

model = await get_embeddings_model(
deployment=deployment, region=AWS_DEFAULT_REGION
)

response = await model.embeddings(
request, embeddings_instruction, embeddings_type
for deployment in EmbeddingsDeployment:
app.add_embeddings(
deployment.deployment_id,
BedrockEmbeddings(region=AWS_DEFAULT_REGION),
)

return make_embeddings_response(deployment, response)
25 changes: 21 additions & 4 deletions aidial_adapter_bedrock/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from abc import ABC
from logging import DEBUG
from typing import Any, AsyncIterator, Optional

import boto3
Expand All @@ -12,6 +13,7 @@
make_async,
to_async_iterator,
)
from aidial_adapter_bedrock.utils.json import json_dumps_short
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


Expand All @@ -37,35 +39,50 @@ def _create_invoke_params(self, model: str, body: dict) -> dict:
}

async def ainvoke_non_streaming(self, model: str, args: dict) -> dict:

if log.isEnabledFor(DEBUG):
log.debug(
f"request: {json_dumps_short({'model': model, 'args': args})}"
)

params = self._create_invoke_params(model, args)
response = await make_async(lambda: self.client.invoke_model(**params))

log.debug(f"response: {response}")
if log.isEnabledFor(DEBUG):
log.debug(f"response: {json_dumps_short(response)}")

body: StreamingBody = response["body"]
body_dict = json.loads(await make_async(lambda: body.read()))

log.debug(f"response['body']: {json.dumps(body_dict)}")
if log.isEnabledFor(DEBUG):
log.debug(f"response['body']: {json_dumps_short(body_dict)}")

return body_dict

async def ainvoke_streaming(
self, model: str, args: dict
) -> AsyncIterator[dict]:
if log.isEnabledFor(DEBUG):
log.debug(
f"request: {json_dumps_short({'model': model, 'args': args})}"
)

params = self._create_invoke_params(model, args)
response = await make_async(
lambda: self.client.invoke_model_with_response_stream(**params)
)

log.debug(f"response: {response}")
if log.isEnabledFor(DEBUG):
log.debug(f"response: {json_dumps_short(response)}")

body: EventStream = response["body"]

async for event in to_async_iterator(iter(body)):
chunk = event.get("chunk")
if chunk:
chunk_dict = json.loads(chunk.get("bytes").decode())
log.debug(f"chunk: {json.dumps(chunk_dict)}")
if log.isEnabledFor(DEBUG):
log.debug(f"chunk: {json_dumps_short(chunk_dict)}")
yield chunk_dict


Expand Down
2 changes: 1 addition & 1 deletion aidial_adapter_bedrock/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def _get_model(
return await get_bedrock_adapter(
region=self.region,
deployment=deployment,
headers=request.headers,
api_key=request.api_key,
)

@dial_exception_decorator
Expand Down
21 changes: 0 additions & 21 deletions aidial_adapter_bedrock/dial_api/auth.py

This file was deleted.

113 changes: 113 additions & 0 deletions aidial_adapter_bedrock/dial_api/embedding_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
List,
TypeVar,
assert_never,
cast,
)

from aidial_sdk.chat_completion.request import Attachment
from aidial_sdk.embeddings.request import EmbeddingsRequest

from aidial_adapter_bedrock.llm.errors import ValidationError

T = TypeVar("T")

Coro = Coroutine[T, Any, Any]
Tokens = List[int]


async def reject_tokens(tokens: Tokens):
raise ValidationError(
"Tokens in the input are not supported, provide text instead. "
"When Langchain AzureOpenAIEmbeddings class is used, set 'check_embedding_ctx_length=False' to disable tokenization."
)


EMPTY_INPUT_LIST_ERROR = ValidationError(
"Empty list in an element of custom_input list"
)

ATTACHMENT_ERROR = ValidationError("Attachments are not supported")


async def reject_attachment(attachment: Attachment):
raise ATTACHMENT_ERROR


async def collect_embedding_inputs(
request: EmbeddingsRequest,
*,
on_text: Callable[[str], Coro[T]],
on_tokens: Callable[[Tokens], Coro[T]] = reject_tokens,
on_attachment: Callable[[Attachment], Coro[T]] = reject_attachment,
on_mixed: Callable[[List[str | Attachment]], Coro[T]],
) -> AsyncIterator[T]:

if isinstance(request.input, str):
yield await on_text(request.input)
elif isinstance(request.input, list):

is_list_of_tokens = False
for input in request.input:
if isinstance(input, str):
yield await on_text(input)
elif isinstance(input, list):
yield await on_tokens(input)
else:
is_list_of_tokens = True
break

if is_list_of_tokens:
yield await on_tokens(cast(Tokens, request.input))

else:
assert_never(request.input)

if request.custom_input is None:
return

for input in request.custom_input:
if isinstance(input, str):
yield await on_text(input)
elif isinstance(input, Attachment):
yield await on_attachment(input)
elif isinstance(input, list):
yield await on_mixed(input)
else:
assert_never(input)


def collect_embedding_inputs_without_attachments(
request: EmbeddingsRequest,
*,
on_texts: Callable[[List[str]], Coro[T]],
on_tokens: Callable[[Tokens], Coro[T]] = reject_tokens,
) -> AsyncIterator[T]:

async def on_text(text: str) -> Coro[T]:
return await on_texts([text])

async def on_mixed(inputs: List[str | Attachment]) -> Coro[T]:
if inputs == []:
raise EMPTY_INPUT_LIST_ERROR

texts: List[str] = []
for input in inputs:
if isinstance(input, str):
texts.append(input)
else:
raise ATTACHMENT_ERROR

return await on_texts(texts)

return collect_embedding_inputs(
request,
on_text=on_text,
on_tokens=on_tokens,
on_attachment=reject_attachment,
on_mixed=on_mixed,
)
22 changes: 1 addition & 21 deletions aidial_adapter_bedrock/dial_api/request.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from enum import Enum
from typing import List, Literal, Optional
from typing import List, Optional

from aidial_sdk.chat_completion.request import ChatCompletionRequest
from pydantic import BaseModel
Expand All @@ -8,7 +7,6 @@
ToolsConfig,
validate_messages,
)
from aidial_adapter_bedrock.utils.pydantic import ExtraAllowModel


class ModelParameters(BaseModel):
Expand Down Expand Up @@ -46,21 +44,3 @@ def create(cls, request: ChatCompletionRequest) -> "ModelParameters":

def add_stop_sequences(self, stop: List[str]) -> "ModelParameters":
return self.copy(update={"stop": [*self.stop, *stop]})


class EmbeddingsType(str, Enum):
SYMMETRIC = "symmetric"
DOCUMENT = "document"
QUERY = "query"


class EmbeddingsRequest(ExtraAllowModel):
model: Optional[str] = None
input: str | List[str]
user: Optional[str] = None
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None

@staticmethod
def example() -> "EmbeddingsRequest":
return EmbeddingsRequest(input=["fish", "ball"])
46 changes: 10 additions & 36 deletions aidial_adapter_bedrock/dial_api/response.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import List, Literal, Tuple, TypedDict
from typing import List, Literal

from aidial_sdk.embeddings import Embedding
from aidial_sdk.embeddings import Response as EmbeddingsResponse
from aidial_sdk.embeddings import Usage
from pydantic import BaseModel

from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage


class ModelObject(BaseModel):
object: Literal["model"] = "model"
Expand All @@ -15,40 +16,13 @@ class ModelsResponse(BaseModel):
data: List[ModelObject]


class EmbeddingsDict(TypedDict):
index: int
object: Literal["embedding"]
embedding: List[float]


class EmbeddingsTokenUsageDict(TypedDict):
prompt_tokens: int
total_tokens: int


class EmbeddingsResponseDict(TypedDict):
object: Literal["list"]
model: str
data: List[EmbeddingsDict]
usage: EmbeddingsTokenUsageDict


def make_embeddings_response(
model_id: str, resp: Tuple[List[List[float]], TokenUsage]
) -> EmbeddingsResponseDict:
vectors, usage = resp
model: str, vectors: List[List[float] | str], usage: Usage
) -> EmbeddingsResponse:

data: List[EmbeddingsDict] = [
{"index": idx, "object": "embedding", "embedding": vec}
for idx, vec in enumerate(vectors)
data: List[Embedding] = [
Embedding(index=index, embedding=embedding)
for index, embedding in enumerate(vectors)
]

return {
"object": "list",
"model": model_id,
"data": data,
"usage": {
"prompt_tokens": usage.prompt_tokens,
"total_tokens": usage.total_tokens,
},
}
return EmbeddingsResponse(model=model, data=data, usage=usage)
Loading

0 comments on commit fc1f4ae

Please sign in to comment.