Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: bump aidial-sdk from 0.14.0 to 0.16.0; bump protobuf from 5.29.0 to 5.29.1 #190

Merged
merged 10 commits into from
Dec 6, 2024
5 changes: 2 additions & 3 deletions aidial_adapter_bedrock/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ async def chat_completion(self, request: Request, response: Response):
async def generate_response(usage: TokenUsage) -> None:
nonlocal discarded_messages

with response.create_choice() as choice:
consumer = ChoiceConsumer(choice=choice)
with ChoiceConsumer(response=response) as consumer:
if isinstance(model, TextCompletionAdapter):
consumer.set_tools_emulator(
model.tools_emulator(params.tool_config)
Expand All @@ -78,7 +77,7 @@ async def generate_response(usage: TokenUsage) -> None:
try:
await model.chat(consumer, params, request.messages)
except UserError as e:
await e.report_usage(choice)
await e.report_usage(consumer.choice)
await response.aflush()
raise e

Expand Down
48 changes: 34 additions & 14 deletions aidial_adapter_bedrock/llm/consumer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Optional, assert_never

from aidial_sdk.chat_completion import (
Attachment,
Choice,
FinishReason,
FunctionCall,
Response,
ToolCall,
)
from pydantic import BaseModel

from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.message import (
Expand All @@ -18,15 +20,6 @@
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages


class Attachment(BaseModel):
    """Attachment payload; every field is optional and defaults to None."""

    # Descriptive metadata.
    type: Optional[str] = None
    title: Optional[str] = None

    # Content, given inline or by location.
    data: Optional[str] = None
    url: Optional[str] = None

    # Reference information.
    reference_url: Optional[str] = None
    reference_type: Optional[str] = None


class Consumer(ABC):
@abstractmethod
def append_content(self, content: str):
Expand Down Expand Up @@ -66,16 +59,43 @@ def has_function_call(self) -> bool:

class ChoiceConsumer(Consumer):
usage: TokenUsage
choice: Choice
response: Response
_choice: Optional[Choice]
discarded_messages: Optional[DiscardedMessages]
tools_emulator: Optional[ToolsEmulator]

def __init__(self, choice: Choice):
self.choice = choice
def __init__(self, response: Response):
    """Bind the consumer to *response*; no choice is created at this point."""
    # Optional collaborators start out unset.
    self.tools_emulator = None
    self.discarded_messages = None

    # Usage starts from a fresh TokenUsage instance.
    self.usage = TokenUsage()

    # The choice is created on first access through the `choice` property.
    self._choice = None
    self.response = response

@property
def choice(self) -> Choice:
    """Return the choice of the bound response, creating it on first access.

    The choice is created and opened lazily, at the very last moment, so
    that exceptions raised before any output is produced can bubble up to
    the level of the HTTP response instead of being reported as error
    objects inside the stream.
    """
    if self._choice is not None:
        return self._choice

    opened = self.response.create_choice()
    # Remember the choice before opening it, so later accesses reuse it.
    self._choice = opened
    opened.open()
    return opened

def __enter__(self):
    """Enter the context and hand back this consumer itself."""
    return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
traceback: TracebackType | None,
) -> bool | None:
if exc is None and self._choice is not None:
self._choice.close()
return False

def set_tools_emulator(self, tools_emulator: ToolsEmulator):
    """Store *tools_emulator* on the consumer, replacing any previous value."""
    self.tools_emulator = tools_emulator

Expand Down Expand Up @@ -118,7 +138,7 @@ def append_content(self, content: str):
self._process_content(content)

def add_attachment(self, attachment: Attachment):
self.choice.add_attachment(**attachment.dict())
self.choice.add_attachment(attachment)

def add_usage(self, usage: TokenUsage):
self.usage.accumulate(usage)
Expand Down
3 changes: 2 additions & 1 deletion aidial_adapter_bedrock/llm/model/stability/v1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional

from aidial_sdk.chat_completion import Attachment
from pydantic import BaseModel, Field

from aidial_adapter_bedrock.bedrock import Bedrock
Expand All @@ -14,7 +15,7 @@
TextCompletionAdapter,
TextCompletionPrompt,
)
from aidial_adapter_bedrock.llm.consumer import Attachment, Consumer
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.message import BaseMessage
from aidial_adapter_bedrock.llm.model.stability.storage import save_to_storage
Expand Down
3 changes: 2 additions & 1 deletion aidial_adapter_bedrock/llm/model/stability/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional, Tuple, assert_never

from aidial_sdk.chat_completion import (
Attachment,
Message,
MessageContentImagePart,
MessageContentTextPart,
Expand All @@ -26,7 +27,7 @@
)
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter
from aidial_adapter_bedrock.llm.consumer import Attachment, Consumer
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.errors import UserError, ValidationError
from aidial_adapter_bedrock.llm.model.stability.storage import save_to_storage
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages
Expand Down
20 changes: 18 additions & 2 deletions aidial_adapter_bedrock/llm/tools/tools_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
FunctionChoice,
Message,
Role,
Tool,
ToolChoice,
)
from aidial_sdk.chat_completion.request import AzureChatCompletionRequest
from aidial_sdk.chat_completion.request import (
AzureChatCompletionRequest,
StaticTool,
)
from pydantic import BaseModel

from aidial_adapter_bedrock.llm.errors import ValidationError
Expand Down Expand Up @@ -108,6 +112,15 @@ def tool_choice_to_function_call(
case _:
return tool_choice

@staticmethod
def _get_function_from_tool(tool: Tool | StaticTool) -> Function:
if isinstance(tool, Tool):
return tool.function
elif isinstance(tool, StaticTool):
raise ValidationError("Static tools aren't supported")
else:
assert_never(tool)

@classmethod
def from_request(cls, request: AzureChatCompletionRequest) -> Self | None:
validate_messages(request)
Expand All @@ -118,7 +131,10 @@ def from_request(cls, request: AzureChatCompletionRequest) -> Self | None:
tool_ids = None

elif request.tools is not None:
functions = [tool.function for tool in request.tools]
functions = [
ToolsConfig._get_function_from_tool(tool)
for tool in request.tools
]
function_call = ToolsConfig.tool_choice_to_function_call(
request.tool_choice
)
Expand Down
44 changes: 31 additions & 13 deletions aidial_adapter_bedrock/server/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import json
from enum import Enum
from functools import wraps
from typing import assert_never

from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InternalServerError, InvalidRequestError
Expand All @@ -44,14 +45,35 @@ class BedrockExceptionCode(Enum):
for the types of exceptions
"""

THROTTLING = "throttlingException"
INTERNAL_SERVER = "internalServerException"
MODEL_STREAM_ERROR = "modelStreamErrorException"
MODEL_TIMEOUT = "modelTimeoutException"
SERVER_UNAVAILABLE = "serviceUnavailableException"
THROTTLING = "throttlingException"
VALIDATION = "validationException"

def __eq__(self, other):
    """Compare this code with a string, ignoring letter case.

    Any non-string operand yields NotImplemented, which lets Python fall
    back to its default comparison.

    NOTE(review): no `__hash__` is visible next to this method; a class
    that overrides `__eq__` without `__hash__` becomes unhashable, so
    confirm members are never used as dict keys or set items.
    """
    if not isinstance(other, str):
        return NotImplemented
    return other.lower() == self.value.lower()

def get_status_code(self) -> int:
    """Return the HTTP status code associated with this exception code."""
    # Members are compared by identity, which is what `==` between two
    # members falls back to, because `__eq__` only handles string operands.
    if self is BedrockExceptionCode.INTERNAL_SERVER:
        return 500
    if self is BedrockExceptionCode.MODEL_STREAM_ERROR:
        return 424
    if self is BedrockExceptionCode.MODEL_TIMEOUT:
        return 408
    if self is BedrockExceptionCode.SERVER_UNAVAILABLE:
        return 503
    if self is BedrockExceptionCode.THROTTLING:
        return 429
    if self is BedrockExceptionCode.VALIDATION:
        return 400
    # Exhaustiveness guard: every member must be handled above.
    assert_never(self)


def _get_meta_status_code(response: dict) -> int | None:
code = response.get("ResponseMetadata", {}).get("HTTPStatusCode")
Expand All @@ -62,16 +84,10 @@ def _get_meta_status_code(response: dict) -> int | None:

def _get_response_error_code(response: dict) -> int | None:
    """Map the error code carried by *response* to an HTTP status code.

    Reads ``response["Error"]["Code"]`` and looks for the
    ``BedrockExceptionCode`` member that equals it.

    Returns:
        The member's status code, or None when the code is missing, is not
        a string, or matches no known member.
    """
    code = response.get("Error", {}).get("Code")
    if not isinstance(code, str):
        return None

    # Compare through `BedrockExceptionCode.__eq__`, which ignores letter
    # case. The value lookup `BedrockExceptionCode(code)` is an exact,
    # case-sensitive match, so a differently-cased code (for example
    # "ThrottlingException") would silently yield None.
    for known_code in BedrockExceptionCode:
        if known_code == code:
            return known_code.get_status_code()
    return None


def _get_content_filter_error(response: dict) -> DialException | None:
Expand Down Expand Up @@ -125,9 +141,11 @@ async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except Exception as e:
dial_exception = to_dial_exception(e)
log.exception(
f"caught exception: {type(e).__module__}.{type(e).__name__}"
f"Caught exception: {type(e).__module__}.{type(e).__name__}. "
f"The exception converted to the dial exception: {dial_exception!r}."
)
raise to_dial_exception(e) from e
raise dial_exception from e

return wrapper
Loading
Loading