Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Stable Diffusion V2 API #168

Merged
merged 14 commits into from
Nov 6, 2024
Merged
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ Note that a model supports `/truncate_prompt` endpoint if and only if it support
|Meta|Llama 3 Chat 8B Instruct|meta.llama3-8b-instruct-v1:0|text-to-text|🟡|🟡|❌|
|Meta|Llama 2 Chat 70B|meta.llama2-70b-chat-v1|text-to-text|🟡|🟡|❌|
|Meta|Llama 2 Chat 13B|meta.llama2-13b-chat-v1|text-to-text|🟡|🟡|❌|
|Stability AI|SDXL 1.0|stability.stable-diffusion-xl-v1|text-to-image|❌|🟡|❌|
|Stability AI|Diffusion 1.0|stability.stable-diffusion-xl-v1|text-to-image|❌|🟡|❌|
|Stability AI|Diffusion 3|stability.sd3-large-v1:0|text-to-image / image-to-image|❌|🟡|❌|
|Amazon|Titan Text G1 - Express|amazon.titan-tg1-large|text-to-text|🟡|🟡|❌|
|AI21 Labs|Jurassic-2 Ultra|ai21.j2-jumbo-instruct|text-to-text|🟡|🟡|❌|
|AI21 Labs|Jurassic-2 Ultra v1|ai21.j2-ultra-v1|text-to-text|🟡|🟡|❌|
Expand Down Expand Up @@ -138,6 +139,7 @@ If you use DIAL Core load balancing mechanism, you can provide `extraData` upstr
```

Supported `extraData` fields:

- `region`
- `aws_access_key_id`
- `aws_secret_access_key`
Expand Down Expand Up @@ -191,4 +193,4 @@ To remove the virtual environment and build artifacts:

```sh
make clean
```
```
1 change: 1 addition & 0 deletions aidial_adapter_bedrock/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ChatCompletionDeployment(str, Enum):

STABILITY_STABLE_DIFFUSION_XL = "stability.stable-diffusion-xl"
STABILITY_STABLE_DIFFUSION_XL_V1 = "stability.stable-diffusion-xl-v1"
STABILITY_STABLE_DIFFUSION_3_LARGE_V1 = "stability.sd3-large-v1:0"

META_LLAMA2_13B_CHAT_V1 = "meta.llama2-13b-chat-v1"
META_LLAMA2_70B_CHAT_V1 = "meta.llama2-70b-chat-v1"
Expand Down
9 changes: 7 additions & 2 deletions aidial_adapter_bedrock/llm/model/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from aidial_adapter_bedrock.llm.model.llama.v2 import llama2_config
from aidial_adapter_bedrock.llm.model.llama.v3 import llama3_config
from aidial_adapter_bedrock.llm.model.meta import MetaAdapter
from aidial_adapter_bedrock.llm.model.stability import StabilityAdapter
from aidial_adapter_bedrock.llm.model.stabililty.v1 import StabilityV1Adapter
from aidial_adapter_bedrock.llm.model.stabililty.v3 import StabilityV3Adapter


async def get_bedrock_adapter(
Expand Down Expand Up @@ -76,7 +77,11 @@ async def get_bedrock_adapter(
ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_XL
| ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_XL_V1
):
return StabilityAdapter.create(
return StabilityV1Adapter.create(
await Bedrock.acreate(aws_client_config), model, api_key
)
case ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_3_LARGE_V1:
return StabilityV3Adapter.create(
await Bedrock.acreate(aws_client_config), model, api_key
)
case ChatCompletionDeployment.AMAZON_TITAN_TG1_LARGE:
Expand Down
Empty file.
22 changes: 22 additions & 0 deletions aidial_adapter_bedrock/llm/model/stabililty/storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from aidial_adapter_bedrock.dial_api.storage import FileStorage
from aidial_adapter_bedrock.llm.consumer import Attachment


async def save_to_storage(
    storage: FileStorage, attachment: Attachment
) -> Attachment:
    """Persist an inline base64 image attachment to DIAL file storage.

    If the attachment carries image data (an ``image/*`` type and a
    non-empty ``data`` payload), it is uploaded under the ``images``
    folder and a new attachment pointing at the resulting URL is
    returned. Any other attachment is returned unchanged.
    """
    has_image_payload = (
        attachment.type is not None
        and attachment.type.startswith("image/")
        and attachment.data is not None
    )
    if not has_image_payload:
        return attachment

    upload = await storage.upload_file_as_base64(
        "images", attachment.data, attachment.type
    )
    # Keep the title and MIME type, but replace inline data with the URL.
    return Attachment(
        title=attachment.title,
        type=attachment.type,
        url=upload["url"],
    )
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def save_to_storage(
return attachment


class StabilityAdapter(TextCompletionAdapter):
class StabilityV1Adapter(TextCompletionAdapter):
model: str
client: Bedrock
storage: Optional[FileStorage]
Expand Down
177 changes: 177 additions & 0 deletions aidial_adapter_bedrock/llm/model/stabililty/v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
from typing import List, Optional, assert_never

from aidial_sdk.chat_completion import (
Message,
MessageContentImagePart,
MessageContentTextPart,
Role,
)
from pydantic import BaseModel

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.resource import (
AttachmentResource,
URLResource,
)
from aidial_adapter_bedrock.dial_api.storage import (
FileStorage,
create_file_storage,
)
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter
from aidial_adapter_bedrock.llm.consumer import Attachment, Consumer
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.model.stabililty.storage import save_to_storage
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages
from aidial_adapter_bedrock.utils.json import remove_nones

SUPPORTED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp"]


class StabilityV3Response(BaseModel):
    """Parsed response body of a Stability AI v3 (SD3) Bedrock invocation."""

    seeds: List[int]
    images: List[str]
    # A reason of None indicates the corresponding request was successful.
    # Possible non-null values:
    #   "Filter reason: prompt"
    #   "Filter reason: output image"
    #   "Filter reason: input image"
    #   "Inference error"
    finish_reasons: List[Optional[str]]

    def content(self) -> str:
        # The textual content is a single space on purpose: the actual
        # payload is delivered via image attachments.
        return " "

    def attachments(self) -> List[Attachment]:
        """Wrap each generated base64 image as a PNG attachment."""
        result: List[Attachment] = []
        for image in self.images:
            result.append(Attachment(type="image/png", data=image))
        return result

    def usage(self) -> TokenUsage:
        # The model is not token-based; report a nominal single completion token.
        return TokenUsage(prompt_tokens=0, completion_tokens=1)

    def throw_if_error(self):
        """Raise on the first non-null finish reason, if any.

        "Inference error" is treated as a server-side failure
        (RuntimeError); every other reason is a content-filter rejection
        reported to the caller as a ValidationError.
        """
        for reason in self.finish_reasons:
            if not reason:
                continue
            if reason == "Inference error":
                raise RuntimeError(reason)
            raise ValidationError(reason)


class StabilityV3Adapter(ChatCompletionAdapter):
    """Adapter for Stability AI SD3 image-generation models on Bedrock.

    Supports text-to-image and image-to-image generation. The prompt and
    the optional single input image are taken from the last user message
    only; all earlier messages are discarded.
    """

    # Bedrock model id, e.g. "stability.sd3-large-v1:0".
    model: str
    client: Bedrock
    # Optional DIAL file storage: used to resolve input image URLs and to
    # persist generated images (None when storage is not configured).
    storage: Optional[FileStorage]

    @classmethod
    def create(cls, client: Bedrock, model: str, api_key: str):
        """Build an adapter with file storage derived from the API key."""
        storage: Optional[FileStorage] = create_file_storage(api_key)
        return cls(
            client=client,
            model=model,
            storage=storage,
        )

    def _validate_last_message(self, messages: List[Message]):
        """Return the last message, requiring it to exist and be from the user."""
        if not messages:
            raise ValidationError("No messages provided")

        last_message = messages[-1]
        if last_message.role != Role.USER:
            raise ValidationError("Last message must be from user")
        return last_message

    async def compute_discarded_messages(
        self, params: ModelParameters, messages: List[Message]
    ) -> DiscardedMessages | None:
        # Only the last (user) message is consumed, so every preceding
        # message index is reported as discarded.
        self._validate_last_message(messages)
        return list(range(len(messages) - 1))

    async def chat(
        self,
        consumer: Consumer,
        params: ModelParameters,
        messages: List[Message],
    ) -> None:
        """Run one generation request and stream the result into *consumer*.

        Raises ValidationError on malformed input (no/non-user last
        message, more than one input image, unsupported image type) and
        propagates errors reported by the model via
        StabilityV3Response.throw_if_error.
        """

        text_prompt = None
        image_data = None  # base64 payload of the (at most one) input image
        last_message = self._validate_last_message(messages)
        # Handle text content
        match last_message.content:
            case str(text):
                text_prompt = text
            case list():
                text_parts = []

                for part in last_message.content:
                    match part:
                        case MessageContentTextPart(text=text):
                            text_parts.append(text)
                        case MessageContentImagePart():
                            # SD3 accepts a single input image; a second
                            # image part is rejected up front.
                            if image_data is not None:
                                raise ValidationError(
                                    "Only one input image is supported"
                                )
                            resource = await URLResource(
                                url=part.image_url.url,
                                supported_types=SUPPORTED_IMAGE_TYPES,
                            ).download(self.storage)
                            image_data = resource.data_base64
                        case _:
                            assert_never(part)
                # Multiple text parts are joined into a single prompt.
                if text_parts:
                    text_prompt = " ".join(text_parts)
            case None:
                pass
            case _:
                assert_never(last_message.content)

        # Legacy attachment-based input image: mutually exclusive with an
        # image supplied via content parts, and limited to one attachment.
        if (
            last_message.custom_content
            and last_message.custom_content.attachments
        ):
            if (
                len(last_message.custom_content.attachments) > 1
                or image_data is not None
            ):
                raise ValidationError("Only one input image is supported")
            resource = await AttachmentResource(
                attachment=last_message.custom_content.attachments[0],
                supported_types=SUPPORTED_IMAGE_TYPES,
            ).download(self.storage)
            image_data = resource.data_base64

        # The mode is inferred from the presence of an input image; None
        # fields (e.g. a missing image) are stripped from the request body.
        response, _ = await self.client.ainvoke_non_streaming(
            self.model,
            remove_nones(
                {
                    "prompt": text_prompt,
                    "image": image_data,
                    "mode": "image-to-image" if image_data else "text-to-image",
                    "output_format": "png",
                }
            ),
        )

        stability_response = StabilityV3Response.parse_obj(response)
        stability_response.throw_if_error()

        consumer.append_content(stability_response.content())
        consumer.close_content()

        consumer.add_usage(stability_response.usage())

        # Generated images are uploaded to file storage when available so
        # the attachment carries a URL instead of inline base64 data.
        for attachment in stability_response.attachments():
            if self.storage:
                attachment = await save_to_storage(self.storage, attachment)
            consumer.add_attachment(attachment)
6 changes: 6 additions & 0 deletions tests/integration_tests/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from aidial_adapter_bedrock.utils.resource import Resource

# Tiny 3x3 truecolor PNG (solid blue, per its name — TODO confirm pixel
# color) used as a minimal valid image fixture in integration tests.
BLUE_PNG_PICTURE = Resource.from_base64(
    type="image/png",
    data_base64="iVBORw0KGgoAAAANSUhEUgAAAAMAAAADCAIAAADZSiLoAAAAF0lEQVR4nGNkYPjPwMDAwMDAxAADCBYAG10BBdmz9y8AAAAASUVORK5CYII=",
)
14 changes: 4 additions & 10 deletions tests/integration_tests/test_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
UpstreamConfig,
)
from aidial_adapter_bedrock.deployments import ChatCompletionDeployment
from aidial_adapter_bedrock.utils.resource import Resource
from tests.integration_tests.constants import BLUE_PNG_PICTURE
from tests.utils.openai import (
GET_WEATHER_FUNCTION,
ChatCompletionResult,
Expand Down Expand Up @@ -180,12 +180,6 @@ def are_tools_emulated(deployment: ChatCompletionDeployment) -> bool:
]


blue_pic = Resource.from_base64(
type="image/png",
data_base64="iVBORw0KGgoAAAANSUhEUgAAAAMAAAADCAIAAADZSiLoAAAAF0lEQVR4nGNkYPjPwMDAwMDAxAADCBYAG10BBdmz9y8AAAAASUVORK5CYII=",
)


def get_test_cases(
deployment: ChatCompletionDeployment, region: str, streaming: bool
) -> List[TestCase]:
Expand Down Expand Up @@ -330,9 +324,9 @@ def dial_recall_expected(r: ChatCompletionResult):
content = "describe the image"
for idx, user_message in enumerate(
[
user_with_attachment_data(content, blue_pic),
user_with_attachment_url(content, blue_pic),
user_with_image_url(content, blue_pic),
user_with_attachment_data(content, BLUE_PNG_PICTURE),
user_with_attachment_url(content, BLUE_PNG_PICTURE),
user_with_image_url(content, BLUE_PNG_PICTURE),
]
):
test_case(
Expand Down
Loading