diff --git a/aidial_adapter_bedrock/server/exceptions.py b/aidial_adapter_bedrock/server/exceptions.py index 990ada9..10f4f9f 100644 --- a/aidial_adapter_bedrock/server/exceptions.py +++ b/aidial_adapter_bedrock/server/exceptions.py @@ -18,6 +18,7 @@ """ import json +from enum import Enum from functools import wraps from aidial_sdk import HTTPException as DialException @@ -34,6 +35,42 @@ def get_exception_type(status_code: int) -> str: return "internal_server_error" +class BedrockExceptionCode(Enum): + """ + See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModelWithResponseStream.html#API_runtime_InvokeModelWithResponseStream_ResponseSyntax + for the types of exceptions + """ + + THROTTLING = "throttlingException" + MODEL_TIMEOUT = "modelTimeoutException" + + def __eq__(self, other): + if isinstance(other, str): + return self.value.lower() == other.lower() + return NotImplemented + + +def _get_meta_status_code(response: dict) -> int | None: + code = response.get("ResponseMetadata", {}).get("HTTPStatusCode") + if isinstance(code, int): + return code + return None + + +def _get_response_error_code(response: dict) -> int | None: + code = response.get("Error", {}).get("Code") + + if isinstance(code, str): + match code: + case BedrockExceptionCode.THROTTLING: + return 429 + case BedrockExceptionCode.MODEL_TIMEOUT: + return 408 + case _: + pass + return None + + def to_dial_exception(e: Exception) -> DialException: if ( isinstance(e, ClientError) @@ -46,19 +83,24 @@ def to_dial_exception(e: Exception) -> DialException: ) status_code = ( - response.get("ResponseMetadata", {}).get("HTTPStatusCode") or 500 + _get_response_error_code(response) + or _get_meta_status_code(response) + or 500 ) return DialException( status_code=status_code, + code=str(status_code), type=get_exception_type(status_code), message=str(e), ) if isinstance(e, APIStatusError): + status_code = e.status_code return DialException( - status_code=e.status_code, - type=get_exception_type(e.status_code), + status_code=status_code, + code=str(status_code), + type=get_exception_type(status_code), message=e.message, ) @@ -71,8 +113,10 @@ def to_dial_exception(e: Exception) -> DialException: if isinstance(e, DialException): return e + status_code = 500 return DialException( - status_code=500, + status_code=status_code, + code=str(status_code), type="internal_server_error", message=str(e), )