Skip to content

Commit

Permalink
fix: reporting 429 error code during streaming in the "code" field (#138)
Browse files: browse the repository at this point in the history
  • Loading branch information
adubovik authored Aug 13, 2024
1 parent 00a9fcc commit cb357ed
Showing 1 changed file with 48 additions and 4 deletions.
52 changes: 48 additions & 4 deletions aidial_adapter_bedrock/server/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""

import json
from enum import Enum
from functools import wraps

from aidial_sdk import HTTPException as DialException
Expand All @@ -34,6 +35,42 @@ def get_exception_type(status_code: int) -> str:
return "internal_server_error"


class BedrockExceptionCode(Enum):
    """
    Bedrock exception codes that get a dedicated HTTP status code.

    See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModelWithResponseStream.html#API_runtime_InvokeModelWithResponseStream_ResponseSyntax
    for the types of exceptions

    Members compare equal to plain strings case-insensitively, so
    ``BedrockExceptionCode.THROTTLING == "ThrottlingException"`` is true.
    Comparison with any other type falls back to the default (identity)
    comparison.
    """

    THROTTLING = "throttlingException"
    MODEL_TIMEOUT = "modelTimeoutException"

    def __eq__(self, other: object) -> bool:
        """Compare case-insensitively against strings; defer otherwise."""
        if isinstance(other, str):
            return self.value.lower() == other.lower()
        # Let Python try the reflected operation and then fall back to
        # identity, which is the right answer for member-vs-member checks.
        return NotImplemented

    def __hash__(self) -> int:
        """Hash on the lower-cased value.

        Overriding ``__eq__`` alone makes Python set ``__hash__`` to ``None``,
        which would leave the members unusable in sets and as dict keys.
        Hashing the lower-cased value keeps the members hashable and agrees
        with ``__eq__`` for lower-cased strings; a string must therefore be
        lower-cased before it is used for a hashed lookup against members.
        """
        return hash(self.value.lower())


def _get_meta_status_code(response: dict) -> int | None:
code = response.get("ResponseMetadata", {}).get("HTTPStatusCode")
if isinstance(code, int):
return code
return None


def _get_response_error_code(response: dict) -> int | None:
code = response.get("Error", {}).get("Code")

if isinstance(code, str):
match code:
case BedrockExceptionCode.THROTTLING:
return 429
case BedrockExceptionCode.MODEL_TIMEOUT:
return 408
case _:
pass
return None


def to_dial_exception(e: Exception) -> DialException:
if (
isinstance(e, ClientError)
Expand All @@ -46,19 +83,24 @@ def to_dial_exception(e: Exception) -> DialException:
)

status_code = (
response.get("ResponseMetadata", {}).get("HTTPStatusCode") or 500
_get_response_error_code(response)
or _get_meta_status_code(response)
or 500
)

return DialException(
status_code=status_code,
code=str(status_code),
type=get_exception_type(status_code),
message=str(e),
)

if isinstance(e, APIStatusError):
status_code = e.status_code
return DialException(
status_code=e.status_code,
type=get_exception_type(e.status_code),
status_code=status_code,
code=str(status_code),
type=get_exception_type(status_code),
message=e.message,
)

Expand All @@ -71,8 +113,10 @@ def to_dial_exception(e: Exception) -> DialException:
if isinstance(e, DialException):
return e

status_code = 500
return DialException(
status_code=500,
status_code=status_code,
code=str(status_code),
type="internal_server_error",
message=str(e),
)
Expand Down

0 comments on commit cb357ed

Please sign in to comment.