diff --git a/aidial_adapter_bedrock/llm/model/stability/v2.py b/aidial_adapter_bedrock/llm/model/stability/v2.py index 2ae9c47..9b5e38a 100644 --- a/aidial_adapter_bedrock/llm/model/stability/v2.py +++ b/aidial_adapter_bedrock/llm/model/stability/v2.py @@ -17,6 +17,7 @@ from aidial_adapter_bedrock.dial_api.resource import ( AttachmentResource, DialResource, + UnsupportedContentType, URLResource, ) from aidial_adapter_bedrock.dial_api.storage import ( @@ -26,13 +27,26 @@ from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter from aidial_adapter_bedrock.llm.consumer import Attachment, Consumer -from aidial_adapter_bedrock.llm.errors import ValidationError +from aidial_adapter_bedrock.llm.errors import UserError, ValidationError from aidial_adapter_bedrock.llm.model.stability.storage import save_to_storage from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages from aidial_adapter_bedrock.utils.json import remove_nones from aidial_adapter_bedrock.utils.resource import Resource SUPPORTED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp"] +SUPPORTED_IMAGE_EXTENSIONS = ["jpeg", "jpe", "jpg", "png", "webp"] + + +async def _download_resource( + dial_resource: DialResource, storage: FileStorage | None +) -> Resource: + try: + return await dial_resource.download(storage) + except UnsupportedContentType as e: + raise UserError( + error_message=f"Unsupported image type: {e.type}", + usage_message=f"Supported image types: {', '.join(SUPPORTED_IMAGE_EXTENSIONS)}", + ) def _validate_image_size( @@ -202,20 +216,23 @@ async def chat( ) if not self.image_to_image_supported and image_resources: - raise ValidationError( - f"Image-to-image is not supported for {self.model}" - ) + raise UserError("Image-to-Image is not supported") if len(image_resources) > 1: - raise ValidationError("Only one input image is supported") + raise UserError("Only one input image is supported") if self.image_to_image_supported and image_resources: - image_resource = await image_resources[0].download(self.storage) + image_resource = await _download_resource( + image_resources[0], self.storage + ) _validate_image_size( image_resource, self.width_constraints, self.height_constraints ) else: image_resource = None + if not text_prompt: + raise UserError("Text prompt is required") + response, _ = await self.client.ainvoke_non_streaming( self.model, remove_nones(