Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Validate image size for Stable Diffusion 3 on adapter side #175

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aidial_adapter_bedrock/llm/model/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ async def get_bedrock_adapter(
model,
api_key,
image_to_image_supported=True,
image_width_constraints=(640, 1536),
image_height_constraints=(640, 1536),
)
case ChatCompletionDeployment.AMAZON_TITAN_TG1_LARGE:
return AmazonAdapter.create(
Expand Down
83 changes: 64 additions & 19 deletions aidial_adapter_bedrock/llm/model/stability/v2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional, assert_never
from io import BytesIO
from typing import List, Optional, Tuple, assert_never

from aidial_sdk.chat_completion import (
Message,
Expand All @@ -7,6 +8,8 @@
Role,
)
from aidial_sdk.chat_completion.request import ImageURL
from aidial_sdk.exceptions import RequestValidationError
from PIL import Image
from pydantic import BaseModel

from aidial_adapter_bedrock.bedrock import Bedrock
Expand All @@ -27,10 +30,51 @@
from aidial_adapter_bedrock.llm.model.stability.storage import save_to_storage
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages
from aidial_adapter_bedrock.utils.json import remove_nones
from aidial_adapter_bedrock.utils.resource import Resource

SUPPORTED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp"]


def _validate_image_size(
image: Resource,
width_constraints: Tuple[int, int] | None,
height_constraints: Tuple[int, int] | None,
) -> None:
if width_constraints is None and height_constraints is None:
return

with Image.open(BytesIO(image.data)) as img:
width, height = img.size

for constraints, value, name in [
(width_constraints, width, "width"),
(height_constraints, height, "height"),
]:
if constraints is None:
continue
min_value, max_value = constraints
if not (min_value <= value <= max_value):
error_msg = (
f"Image {name} is {value}, but should be "
f"between {min_value} and {max_value}"
)
raise RequestValidationError(
message=error_msg,
display_message=error_msg,
code="invalid_argument",
)


def _validate_last_message(messages: List[Message]):
if not messages:
raise ValidationError("No messages provided")

last_message = messages[-1]
if last_message.role != Role.USER:
raise ValidationError("Last message must be from user")
return last_message


class StabilityV2Response(BaseModel):
seeds: List[int]
images: List[str]
Expand Down Expand Up @@ -74,6 +118,8 @@ class StabilityV2Adapter(ChatCompletionAdapter):
client: Bedrock
storage: Optional[FileStorage]
image_to_image_supported: bool
width_constraints: Tuple[int, int] | None
height_constraints: Tuple[int, int] | None

@classmethod
def create(
Expand All @@ -82,28 +128,23 @@ def create(
model: str,
api_key: str,
image_to_image_supported: bool,
image_width_constraints: Tuple[int, int] | None = None,
image_height_constraints: Tuple[int, int] | None = None,
):
storage: Optional[FileStorage] = create_file_storage(api_key)
return cls(
client=client,
model=model,
storage=storage,
image_to_image_supported=image_to_image_supported,
width_constraints=image_width_constraints,
height_constraints=image_height_constraints,
)

def _validate_last_message(self, messages: List[Message]):
if not messages:
raise ValidationError("No messages provided")

last_message = messages[-1]
if last_message.role != Role.USER:
raise ValidationError("Last message must be from user")
return last_message

async def compute_discarded_messages(
self, params: ModelParameters, messages: List[Message]
) -> DiscardedMessages | None:
self._validate_last_message(messages)
_validate_last_message(messages)
return list(range(len(messages) - 1))

async def chat(
Expand All @@ -115,7 +156,7 @@ async def chat(

text_prompt = None
image_resources: List[DialResource] = []
last_message = self._validate_last_message(messages)
last_message = _validate_last_message(messages)
# Handle text content
match last_message.content:
case str(text):
Expand Down Expand Up @@ -166,26 +207,30 @@ async def chat(
if len(image_resources) > 1:
raise ValidationError("Only one input image is supported")

if self.image_to_image_supported and image_resources:
image_resource = await image_resources[0].download(self.storage)
_validate_image_size(
image_resource, self.width_constraints, self.height_constraints
)
else:
image_resource = None

response, _ = await self.client.ainvoke_non_streaming(
self.model,
remove_nones(
{
"prompt": text_prompt,
"image": (
(
await image_resources[0].download(self.storage)
).data_base64
if image_resources
else None
image_resource.data_base64 if image_resource else None
),
"mode": (
"image-to-image" if image_resources else "text-to-image"
"image-to-image" if image_resource else "text-to-image"
),
"output_format": "png",
# This parameter controls how much input image will affect generation from 0 to 1,
# where 0 means that output will be identical to input image and 1 means that model will ignore input image
# Since there is no recommended default value, we use 0.5 as a middle ground
"strength": 0.5 if image_resources else None,
"strength": 0.5 if image_resource else None,
}
),
)
Expand Down
Binary file modified tests/integration_tests/images/dog-sample-image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 6 additions & 2 deletions tests/integration_tests/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,12 @@ async def test_image_to_image_with_too_small_picture(
model=deployment.value,
messages=[user_with_image_content_part("test", BLUE_PNG_PICTURE)],
)
assert exc_info.value.status_code == 400
assert "width must be between 640 and 1536" in exc_info.value.message

assert exc_info.value.status_code == 422
assert (
"Image width is 3, but should be between 640 and 1536"
in exc_info.value.message
)


@pytest.mark.parametrize(
Expand Down
Loading