diff --git a/aidial_adapter_bedrock/dial_api/embedding_inputs.py b/aidial_adapter_bedrock/dial_api/embedding_inputs.py index 82bd573..887f211 100644 --- a/aidial_adapter_bedrock/dial_api/embedding_inputs.py +++ b/aidial_adapter_bedrock/dial_api/embedding_inputs.py @@ -9,7 +9,7 @@ cast, ) -from aidial_sdk.chat_completion.request import Attachment +from aidial_sdk.chat_completion import Attachment from aidial_sdk.embeddings.request import EmbeddingsRequest from aidial_adapter_bedrock.llm.errors import ValidationError diff --git a/aidial_adapter_bedrock/dial_api/request.py b/aidial_adapter_bedrock/dial_api/request.py index a753343..a567e71 100644 --- a/aidial_adapter_bedrock/dial_api/request.py +++ b/aidial_adapter_bedrock/dial_api/request.py @@ -1,14 +1,27 @@ -from typing import List, Optional +from typing import List, Optional, TypeGuard, assert_never +from aidial_sdk.chat_completion import ( + MessageContentImagePart, + MessageContentPart, + MessageContentTextPart, +) from aidial_sdk.chat_completion.request import ChatCompletionRequest from pydantic import BaseModel +from aidial_adapter_bedrock.llm.errors import ValidationError from aidial_adapter_bedrock.llm.tools.tools_config import ( ToolsConfig, ToolsMode, validate_messages, ) +MessageContent = str | List[MessageContentPart] | None +MessageContentSpecialized = ( + MessageContent + | List[MessageContentTextPart] + | List[MessageContentImagePart] +) + class ModelParameters(BaseModel): temperature: Optional[float] = None @@ -51,3 +64,55 @@ def tools_mode(self) -> ToolsMode | None: if self.tool_config is not None: return self.tool_config.tools_mode return None + + +def collect_text_content( + content: MessageContentSpecialized, delimiter: str = "\n\n" +) -> str: + + if content is None: + return "" + + if isinstance(content, str): + return content + + texts: List[str] = [] + for part in content: + if isinstance(part, MessageContentTextPart): + texts.append(part.text) + else: + raise ValidationError( + "Can't extract text from a multi-modal content part" + ) + + return delimiter.join(texts) + + +def to_message_content(content: MessageContentSpecialized) -> MessageContent: + match content: + case None | str(): + return content + case list(): + return [*content] + case _: + assert_never(content) + + +def is_text_content( + content: MessageContent, +) -> TypeGuard[str | List[MessageContentTextPart]]: + match content: + case None: + return False + case str(): + return True + case list(): + return all( + isinstance(part, MessageContentTextPart) for part in content + ) + case _: + assert_never(content) + + +def is_plain_text_content(content: MessageContent) -> TypeGuard[str | None]: + return content is None or isinstance(content, str) diff --git a/aidial_adapter_bedrock/dial_api/resource.py b/aidial_adapter_bedrock/dial_api/resource.py new file mode 100644 index 0000000..57f04a1 --- /dev/null +++ b/aidial_adapter_bedrock/dial_api/resource.py @@ -0,0 +1,181 @@ +import base64 +import mimetypes +from abc import ABC, abstractmethod +from typing import List + +from aidial_sdk.chat_completion import Attachment +from pydantic import BaseModel, Field, root_validator, validator + +from aidial_adapter_bedrock.dial_api.storage import FileStorage, download_file +from aidial_adapter_bedrock.utils.resource import Resource +from aidial_adapter_bedrock.utils.text import truncate_string + + +class ValidationError(Exception): + message: str + + def __init__(self, message: str): + self.message = message + super().__init__(message) + + +class MissingContentType(ValidationError): + pass + + +class UnsupportedContentType(ValidationError): + type: str + supported_types: List[str] + + def __init__(self, *, message: str, type: str, supported_types: List[str]): + self.type = type + self.supported_types = supported_types + super().__init__(message) + + +class DialResource(ABC, BaseModel): + entity_name: str = Field(default=None) + supported_types: List[str] | None = Field(default=None) + + @abstractmethod + async def download(self, storage: FileStorage | None) -> Resource: ... + + @abstractmethod + async def guess_content_type(self) -> str | None: ... + + @abstractmethod + async def get_resource_name(self, storage: FileStorage | None) -> str: ... + + async def get_content_type(self) -> str: + type = await self.guess_content_type() + + if not type: + raise MissingContentType( + f"Can't derive content type of the {self.entity_name}" + ) + + if ( + self.supported_types is not None + and type not in self.supported_types + ): + raise UnsupportedContentType( + message=f"The {self.entity_name} is not one of the supported types", + type=type, + supported_types=self.supported_types, + ) + + return type + + +class URLResource(DialResource): + url: str + content_type: str | None = None + + @root_validator + def validator(cls, values): + values["entity_name"] = values.get("entity_name") or "URL" + return values + + async def download(self, storage: FileStorage | None) -> Resource: + type = await self.get_content_type() + data = await _download_url(storage, self.url) + return Resource(type=type, data=data) + + async def guess_content_type(self) -> str | None: + return ( + self.content_type + or Resource.parse_data_url_content_type(self.url) + or mimetypes.guess_type(self.url)[0] + ) + + def is_data_url(self) -> bool: + return Resource.parse_data_url_content_type(self.url) is not None + + async def get_resource_name(self, storage: FileStorage | None) -> str: + if self.is_data_url(): + return f"data URL ({await self.guess_content_type()})" + + name = self.url + if storage is not None: + name = await storage.get_human_readable_name(self.url) + + return truncate_string(name, n=50) + + +class AttachmentResource(DialResource): + attachment: Attachment + + @validator("attachment", pre=True) + def parse_attachment(cls, value): + if isinstance(value, dict): + attachment = Attachment.parse_obj(value) + # Working around the issue of defaulting missing type to a markdown: + # https://github.com/epam/ai-dial-sdk/blob/2835107e950c89645a2b619fecba2518fa2d7bb1/aidial_sdk/chat_completion/request.py#L22 + if "type" not in value: + attachment.type = None + return attachment + return value + + @root_validator(pre=True) + def validator(cls, values): + values["entity_name"] = values.get("entity_name") or "attachment" + return values + + async def download(self, storage: FileStorage | None) -> Resource: + type = await self.get_content_type() + + if self.attachment.data: + data = base64.b64decode(self.attachment.data) + elif self.attachment.url: + data = await _download_url(storage, self.attachment.url) + else: + raise ValidationError(f"Invalid {self.entity_name}") + + return Resource(type=type, data=data) + + def create_url_resource(self, url: str) -> URLResource: + return URLResource( + url=url, + content_type=self.informative_content_type, + entity_name=self.entity_name, + ) + + @property + def informative_content_type(self) -> str | None: + if ( + self.attachment.type is None + or "octet-stream" in self.attachment.type + ): + return None + return self.attachment.type + + async def guess_content_type(self) -> str | None: + if url := self.attachment.url: + type = await self.create_url_resource(url).guess_content_type() + if type: + return type + + return self.attachment.type + + async def get_resource_name(self, storage: FileStorage | None) -> str: + if title := self.attachment.title: + return title + + if self.attachment.data: + return f"data {self.entity_name}" + elif url := self.attachment.url: + return await self.create_url_resource(url).get_resource_name( + storage + ) + else: + raise ValidationError(f"Invalid {self.entity_name}") + + +async def _download_url(file_storage: FileStorage | None, url: str) -> bytes: + if (resource := Resource.from_data_url(url)) is not None: + return resource.data + + if file_storage: + return await file_storage.download_file(url) + else: + return await download_file(url) diff --git a/aidial_adapter_bedrock/dial_api/storage.py b/aidial_adapter_bedrock/dial_api/storage.py index 3734258..9095abc 100644 --- a/aidial_adapter_bedrock/dial_api/storage.py +++ b/aidial_adapter_bedrock/dial_api/storage.py @@ -4,11 +4,12 @@ import mimetypes import os from typing import Mapping, Optional, TypedDict -from urllib.parse import urljoin +from urllib.parse import unquote, urljoin import aiohttp +from pydantic import BaseModel -from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log +from aidial_adapter_bedrock.utils.log_config import app_logger as log class FileMetadata(TypedDict): @@ -23,15 +24,10 @@ class Bucket(TypedDict): appdata: str -class FileStorage: +class FileStorage(BaseModel): dial_url: str api_key: str - bucket: Optional[Bucket] - - def __init__(self, dial_url: str, api_key: str): - self.dial_url = dial_url - self.api_key = api_key - self.bucket = None + bucket: Optional[Bucket] = None @property def auth_headers(self) -> Mapping[str, str]: @@ -49,6 +45,15 @@ async def _get_bucket(self, session: aiohttp.ClientSession) -> Bucket: return self.bucket + async def _get_user_bucket(self, session: aiohttp.ClientSession) -> str: + bucket = await self._get_bucket(session) + appdata = bucket.get("appdata") + if appdata is None: + raise ValueError( + "Can't retrieve user bucket because appdata isn't available" + ) + return appdata.split("/", 1)[0] + @staticmethod def _to_form_data( filename: str, content_type: str, content: bytes @@ -87,36 +92,48 @@ async def upload( async def upload_file_as_base64( self, upload_dir: str, data: str, content_type: str ) -> FileMetadata: - filename = f"{upload_dir}/{_compute_hash_digest(data)}" + filename = f"{upload_dir}/{compute_hash_digest(data)}" content: bytes = base64.b64decode(data) return await self.upload(filename, content_type, content) - async def download_file_as_base64(self, dial_path: str) -> str: - url = urljoin(f"{self.dial_url}/v1/", dial_path) + def attachment_link_to_url(self, link: str) -> str: + return urljoin(f"{self.dial_url}/v1/", link) + + def _url_to_attachment_link(self, url: str) -> str: + return url.removeprefix(f"{self.dial_url}/v1/") + + async def download_file(self, link: str) -> bytes: + url = self.attachment_link_to_url(link) headers: Mapping[str, str] = {} if url.lower().startswith(self.dial_url.lower()): headers = self.auth_headers + return await download_file(url, headers) + + async def get_human_readable_name(self, link: str) -> str: + url = self.attachment_link_to_url(link) + link = self._url_to_attachment_link(url) - return await download_file_as_base64(url, headers) + link = link.removeprefix("files/") + if link.startswith("public/"): + bucket = "public" + else: + async with aiohttp.ClientSession() as session: + bucket = await self._get_user_bucket(session) -async def _download_file( - url: str, headers: Optional[Mapping[str, str]] -) -> bytes: + link = link.removeprefix(f"{bucket}/") + decoded_link = unquote(link) + return link if link == decoded_link else repr(decoded_link) + + +async def download_file(url: str, headers: Mapping[str, str] = {}) -> bytes: async with aiohttp.ClientSession() as session: async with session.get(url, headers=headers) as response: response.raise_for_status() return await response.read() -async def download_file_as_base64( - url: str, headers: Optional[Mapping[str, str]] = None -) -> str: - data = await _download_file(url, headers) - return base64.b64encode(data).decode("ascii") - - -def _compute_hash_digest(file_content: str) -> str: +def compute_hash_digest(file_content: str) -> str: return hashlib.sha256(file_content.encode()).hexdigest() diff --git a/aidial_adapter_bedrock/embedding/amazon/titan_image.py b/aidial_adapter_bedrock/embedding/amazon/titan_image.py index a13f5eb..87a65f4 100644 --- a/aidial_adapter_bedrock/embedding/amazon/titan_image.py +++ b/aidial_adapter_bedrock/embedding/amazon/titan_image.py @@ -7,7 +7,7 @@ from typing import AsyncIterator, List, Self -from aidial_sdk.chat_completion.request import Attachment +from aidial_sdk.chat_completion import Attachment from aidial_sdk.embeddings import Response as EmbeddingsResponse from aidial_sdk.embeddings import Usage from aidial_sdk.embeddings.request import EmbeddingsRequest @@ -18,6 +18,7 @@ EMPTY_INPUT_LIST_ERROR, collect_embedding_inputs, ) +from aidial_adapter_bedrock.dial_api.resource import AttachmentResource from aidial_adapter_bedrock.dial_api.response import make_embeddings_response from aidial_adapter_bedrock.dial_api.storage import ( FileStorage, @@ -26,7 +27,6 @@ from aidial_adapter_bedrock.embedding.amazon.response import ( call_embedding_model, ) -from aidial_adapter_bedrock.embedding.attachments import download_base64_data from aidial_adapter_bedrock.embedding.embeddings_adapter import ( EmbeddingsAdapter, ) @@ -34,9 +34,11 @@ from aidial_adapter_bedrock.embedding.validation import ( validate_embeddings_request, ) -from aidial_adapter_bedrock.llm.errors import ValidationError +from aidial_adapter_bedrock.llm.errors import UserError, ValidationError from aidial_adapter_bedrock.utils.json import remove_nones +IMAGE_MEDIA_TYPES = ["image/png"] + class AmazonRequest(BaseModel): inputText: str | None = None @@ -63,25 +65,29 @@ def create_titan_request( ) -async def download_image( - attachment: Attachment, storage: FileStorage | None -) -> str: - _content_type, data = await download_base64_data( - attachment, storage, ["image/png"] - ) - return data +def _validate_content_type(content_type: str, supported_types: List[str]): + if content_type not in supported_types: + raise UserError( + f"Unsupported attachment content type: {content_type}. " + f"Supported attachment types: {', '.join(supported_types)}." + ) def get_requests( - request: EmbeddingsRequest, storage: FileStorage | None + file_storage: FileStorage | None, request: EmbeddingsRequest ) -> AsyncIterator[AmazonRequest]: + async def download_image(attachment: Attachment) -> str: + resource = await AttachmentResource(attachment=attachment).download( + file_storage + ) + _validate_content_type(resource.type, IMAGE_MEDIA_TYPES) + return resource.data_base64 + async def on_text(text: str) -> AmazonRequest: return AmazonRequest(inputText=text) async def on_attachment(attachment: Attachment) -> AmazonRequest: - return AmazonRequest( - inputImage=await download_image(attachment, storage) - ) + return AmazonRequest(inputImage=await download_image(attachment)) async def on_text_or_attachment(text: str | Attachment) -> AmazonRequest: if isinstance(text, str): @@ -98,14 +104,14 @@ async def on_mixed(inputs: List[str | Attachment]) -> AmazonRequest: if isinstance(inputs[0], str) and isinstance(inputs[1], Attachment): return AmazonRequest( inputText=inputs[0], - inputImage=await download_image(inputs[1], storage), + inputImage=await download_image(inputs[1]), ) elif isinstance(inputs[0], Attachment) and isinstance( inputs[1], str ): return AmazonRequest( inputText=inputs[1], - inputImage=await download_image(inputs[0], storage), + inputImage=await download_image(inputs[0]), ) else: raise ValidationError( @@ -153,7 +159,8 @@ async def embeddings( token_count = 0 # NOTE: Amazon Titan doesn't support batched inputs - async for sub_request in get_requests(request, self.storage): + # TODO: create multiple tasks + async for sub_request in get_requests(self.storage, request): embedding, text_tokens = await call_embedding_model( self.client, self.model, diff --git a/aidial_adapter_bedrock/embedding/attachments.py b/aidial_adapter_bedrock/embedding/attachments.py deleted file mode 100644 index 7852da6..0000000 --- a/aidial_adapter_bedrock/embedding/attachments.py +++ /dev/null @@ -1,57 +0,0 @@ -import mimetypes -from typing import List, Optional, Tuple - -from aidial_sdk.chat_completion import Attachment - -from aidial_adapter_bedrock.dial_api.storage import ( - FileStorage, - download_file_as_base64, -) -from aidial_adapter_bedrock.llm.errors import UserError, ValidationError - - -async def _download_base64_data( - url: str, file_storage: Optional[FileStorage] -) -> str: - if not file_storage: - return await download_file_as_base64(url) - return await file_storage.download_file_as_base64(url) - - -def _validate_content_type( - content_type: str, supported_content_types: List[str] -): - if content_type not in supported_content_types: - raise UserError( - f"Unsupported attachment type: {content_type}. " - f"Supported attachment types: {', '.join(supported_content_types)}.", - ) - - -async def download_base64_data( - attachment: Attachment, - file_storage: Optional[FileStorage], - supported_content_types: List[str], -) -> Tuple[str, str]: - if attachment.data: - if not attachment.type: - raise ValidationError( - "Attachment type is required for provided data" - ) - _validate_content_type(attachment.type, supported_content_types) - return attachment.type, attachment.data - - if attachment.url: - content_type = ( - attachment.type or mimetypes.guess_type(attachment.url)[0] - ) - if not content_type: - raise ValidationError( - f"Cannot guess content type of attachment {attachment.url}" - ) - _validate_content_type(content_type, supported_content_types) - - data = await _download_base64_data(attachment.url, file_storage) - return content_type, data - - raise ValidationError("Attachment data or URL is required") diff --git a/aidial_adapter_bedrock/llm/chat_emulator.py b/aidial_adapter_bedrock/llm/chat_emulator.py index 4b4848a..0530d33 100644 --- a/aidial_adapter_bedrock/llm/chat_emulator.py +++ b/aidial_adapter_bedrock/llm/chat_emulator.py @@ -58,7 +58,7 @@ def _format_message(self, message: BaseMessage, idx: int) -> str: else: cue_prefix = cue + " " - return (cue_prefix + message.content.lstrip()).rstrip() + return (cue_prefix + message.text_content.lstrip()).rstrip() def get_ai_cue(self) -> Optional[str]: return self.cues["ai"] @@ -69,7 +69,7 @@ def display(self, messages: List[BaseMessage]) -> Tuple[str, List[str]]: and len(messages) == 1 and isinstance(messages[0], HumanRegularMessage) ): - return messages[0].content, [] + return messages[0].text_content, [] ret: List[str] = [] diff --git a/aidial_adapter_bedrock/llm/chat_model.py b/aidial_adapter_bedrock/llm/chat_model.py index 567988a..b61ea6b 100644 --- a/aidial_adapter_bedrock/llm/chat_model.py +++ b/aidial_adapter_bedrock/llm/chat_model.py @@ -6,7 +6,10 @@ from typing_extensions import override import aidial_adapter_bedrock.utils.stream as stream_utils -from aidial_adapter_bedrock.dial_api.request import ModelParameters +from aidial_adapter_bedrock.dial_api.request import ( + ModelParameters, + collect_text_content, +) from aidial_adapter_bedrock.llm.chat_emulator import ChatEmulator from aidial_adapter_bedrock.llm.consumer import Consumer from aidial_adapter_bedrock.llm.errors import ValidationError @@ -24,8 +27,7 @@ def _is_empty_system_message(msg: Message) -> bool: return ( msg.role == Role.SYSTEM - and msg.content is not None - and msg.content.strip() == "" + and collect_text_content(msg.content).strip() == "" ) diff --git a/aidial_adapter_bedrock/llm/message.py b/aidial_adapter_bedrock/llm/message.py index dfcc39c..1683480 100644 --- a/aidial_adapter_bedrock/llm/message.py +++ b/aidial_adapter_bedrock/llm/message.py @@ -1,95 +1,260 @@ -from typing import List, Optional, Union +from abc import ABC, abstractmethod +from typing import List, Optional, Self, Union +from aidial_sdk.chat_completion import Attachment, CustomContent, FunctionCall +from aidial_sdk.chat_completion import Message as DialMessage from aidial_sdk.chat_completion import ( - CustomContent, - FunctionCall, - Message, + MessageContentPart, + MessageContentTextPart, Role, ToolCall, ) from pydantic import BaseModel +from aidial_adapter_bedrock.dial_api.request import ( + collect_text_content, + is_plain_text_content, + is_text_content, + to_message_content, +) from aidial_adapter_bedrock.llm.errors import ValidationError -class SystemMessage(BaseModel): - content: str +class MessageABC(ABC, BaseModel): + @abstractmethod + def to_message(self) -> DialMessage: ... - def to_message(self) -> Message: - return Message(role=Role.SYSTEM, content=self.content) + @classmethod + @abstractmethod + def from_message(cls, message: DialMessage) -> Self | None: ... -class HumanRegularMessage(BaseModel): - content: str +class BaseMessageABC(MessageABC): + @property + @abstractmethod + def text_content(self) -> str: ... + + +class SystemMessage(BaseMessageABC): + content: str | List[MessageContentTextPart] + + def to_message(self) -> DialMessage: + return DialMessage( + role=Role.SYSTEM, + content=to_message_content(self.content), + ) + + @classmethod + def from_message(cls, message: DialMessage) -> Self | None: + if message.role != Role.SYSTEM: + return None + + content = message.content + + if not is_text_content(content): + raise ValidationError( + "System message is expected to be a string or a list of text content parts" + ) + + return cls(content=content) + + @property + def text_content(self) -> str: + return collect_text_content(self.content) + + +class HumanRegularMessage(BaseMessageABC): + """MM stands for multi-modal""" + + content: str | List[MessageContentPart] custom_content: Optional[CustomContent] = None - def to_message(self) -> Message: - return Message( + def to_message(self) -> DialMessage: + return DialMessage( role=Role.USER, content=self.content, custom_content=self.custom_content, ) + @classmethod + def from_message(cls, message: DialMessage) -> Self | None: + if message.role != Role.USER: + return None -class HumanToolResultMessage(BaseModel): + content = message.content + if content is None: + raise ValidationError( + "User message is expected to have content field" + ) + + return cls(content=content, custom_content=message.custom_content) + + @property + def text_content(self) -> str: + return collect_text_content(self.content) + + @property + def attachments(self) -> List[Attachment]: + return ( + self.custom_content.attachments or [] if self.custom_content else [] + ) + + +class HumanToolResultMessage(MessageABC): id: str content: str - def to_message(self) -> Message: - return Message( + def to_message(self) -> DialMessage: + return DialMessage( role=Role.TOOL, tool_call_id=self.id, content=self.content, ) + @classmethod + def from_message(cls, message: DialMessage) -> Self | None: + if message.role != Role.TOOL: + return None -class HumanFunctionResultMessage(BaseModel): + if not is_plain_text_content(message.content): + raise ValidationError( + "The tool message shouldn't contain content parts" + ) + + if message.content is None or message.tool_call_id is None: + raise ValidationError( + "The tool message is expected to have content and tool_call_id fields" + ) + + return cls(id=message.tool_call_id, content=message.content) + + +class HumanFunctionResultMessage(MessageABC): name: str content: str - def to_message(self) -> Message: - return Message( + def to_message(self) -> DialMessage: + return DialMessage( role=Role.FUNCTION, name=self.name, content=self.content, ) + @classmethod + def from_message(cls, message: DialMessage) -> Self | None: + if message.role != Role.FUNCTION: + return None + + if not is_plain_text_content(message.content): + raise ValidationError( + "The function message shouldn't contain content parts" + ) -class AIRegularMessage(BaseModel): + if message.content is None or message.name is None: + raise ValidationError( + "The function message is expected to have content and name fields" + ) + + return cls(name=message.name, content=message.content) + + +class AIRegularMessage(BaseMessageABC): content: str custom_content: Optional[CustomContent] = None - def to_message(self) -> Message: - return Message( + def to_message(self) -> DialMessage: + return DialMessage( role=Role.ASSISTANT, content=self.content, custom_content=self.custom_content, ) + @classmethod + def from_message(cls, message: DialMessage) -> Self | None: + if message.role != Role.ASSISTANT: + return None + + if message.function_call is not None or message.tool_calls is not None: + return None + + if not is_plain_text_content(message.content): + raise ValidationError( + "The assistant message shouldn't contain content parts" + ) + + if message.content is None: + raise ValidationError( + "The assistant message is expected to have content" + ) + + return cls( + content=message.content, custom_content=message.custom_content + ) + + @property + def text_content(self) -> str: + return self.content + + @property + def attachments(self) -> List[Attachment]: + return ( + self.custom_content.attachments or [] if self.custom_content else [] + ) + -class AIToolCallMessage(BaseModel): +class AIToolCallMessage(MessageABC): calls: List[ToolCall] content: Optional[str] = None - def to_message(self) -> Message: - return Message( + def to_message(self) -> DialMessage: + return DialMessage( role=Role.ASSISTANT, content=self.content, tool_calls=self.calls, ) + @classmethod + def from_message(cls, message: DialMessage) -> Self | None: + if message.role != Role.ASSISTANT: + return None -class AIFunctionCallMessage(BaseModel): + if message.tool_calls is None or message.function_call is not None: + return None + + if not is_plain_text_content(message.content): + raise ValidationError( + "The assistant message with tool calls shouldn't contain content parts" + ) + + return cls(calls=message.tool_calls, content=message.content) + + +class AIFunctionCallMessage(MessageABC): call: FunctionCall content: Optional[str] = None - def to_message(self) -> Message: - return Message( + def to_message(self) -> DialMessage: + return DialMessage( role=Role.ASSISTANT, content=self.content, function_call=self.call, ) + @classmethod + def from_message(cls, message: DialMessage) -> Self | None: + if message.role != Role.ASSISTANT: + return None + + if message.function_call is None or message.tool_calls is not None: + return None + + if not is_plain_text_content(message.content): + raise ValidationError( + "The assistant message with function call shouldn't contain content parts" + ) + + return cls(call=message.function_call, content=message.content) + BaseMessage = Union[SystemMessage, HumanRegularMessage, AIRegularMessage] @@ -101,51 +266,19 @@ def to_message(self) -> Message: ] -def _parse_assistant_message( - content: Optional[str], - function_call: Optional[FunctionCall], - tool_calls: Optional[List[ToolCall]], - custom_content: Optional[CustomContent], -) -> BaseMessage | ToolMessage: - if content is not None and function_call is None and tool_calls is None: - return AIRegularMessage(content=content, custom_content=custom_content) - - if function_call is not None and tool_calls is None: - return AIFunctionCallMessage(call=function_call, content=content) - - if function_call is None and tool_calls is not None: - return AIToolCallMessage(calls=tool_calls, content=content) +def parse_dial_message(msg: DialMessage) -> BaseMessage | ToolMessage: - raise ValidationError("Unknown type of assistant message") + message = ( + SystemMessage.from_message(msg) + or HumanRegularMessage.from_message(msg) + or HumanToolResultMessage.from_message(msg) + or HumanFunctionResultMessage.from_message(msg) + or AIRegularMessage.from_message(msg) + or AIToolCallMessage.from_message(msg) + or AIFunctionCallMessage.from_message(msg) + ) + if message is None: + raise ValidationError("Unknown message type or invalid message") -def parse_dial_message(msg: Message) -> BaseMessage | ToolMessage: - match msg: - case Message(role=Role.SYSTEM, content=content) if content is not None: - return SystemMessage(content=content) - case Message( - role=Role.USER, content=content, custom_content=custom_content - ) if content is not None: - return HumanRegularMessage( - content=content, custom_content=custom_content - ) - case Message( - role=Role.ASSISTANT, - content=content, - function_call=function_call, - tool_calls=tool_calls, - custom_content=custom_content, - ): - return _parse_assistant_message( - content, function_call, tool_calls, custom_content - ) - case Message( - role=Role.FUNCTION, name=name, content=content - ) if content is not None and name is not None: - return HumanFunctionResultMessage(name=name, content=content) - case Message( - role=Role.TOOL, tool_call_id=id, content=content - ) if content is not None and id is not None: - return HumanToolResultMessage(id=id, content=content) - case _: - raise ValidationError("Unknown message type or invalid message") + return message diff --git a/aidial_adapter_bedrock/llm/model/claude/v3/converters.py b/aidial_adapter_bedrock/llm/model/claude/v3/converters.py index edfb53c..d3dabf9 100644 --- a/aidial_adapter_bedrock/llm/model/claude/v3/converters.py +++ b/aidial_adapter_bedrock/llm/model/claude/v3/converters.py @@ -1,11 +1,11 @@ import json -import mimetypes -from typing import List, Literal, Optional, Set, Tuple, assert_never, cast +from typing import List, Literal, Optional, Tuple, assert_never, cast from aidial_sdk.chat_completion import ( - Attachment, FinishReason, Function, + MessageContentImagePart, + MessageContentTextPart, ToolCall, ) from anthropic.types import ( @@ -18,10 +18,13 @@ ) from anthropic.types.image_block_param import Source -from aidial_adapter_bedrock.dial_api.storage import ( - FileStorage, - download_file_as_base64, +from aidial_adapter_bedrock.dial_api.resource import ( + AttachmentResource, + DialResource, + UnsupportedContentType, + URLResource, ) +from aidial_adapter_bedrock.dial_api.storage import FileStorage from aidial_adapter_bedrock.llm.errors import UserError, ValidationError from aidial_adapter_bedrock.llm.message import ( AIRegularMessage, @@ -32,87 +35,92 @@ SystemMessage, ) from aidial_adapter_bedrock.llm.tools.tools_config import ToolsMode +from aidial_adapter_bedrock.utils.resource import Resource ClaudeFinishReason = Literal[ "end_turn", "max_tokens", "stop_sequence", "tool_use" ] ImageMediaType = Literal["image/png", "image/jpeg", "image/gif", "image/webp"] -IMAGE_MEDIA_TYPES: Set[ImageMediaType] = { +IMAGE_MEDIA_TYPES: List[str] = [ "image/png", "image/jpeg", "image/gif", "image/webp", -} +] FILE_EXTENSIONS = ["png", "jpeg", "jpg", "gif", "webp"] -def _validate_media_type(media_type: str) -> ImageMediaType: - if media_type not in IMAGE_MEDIA_TYPES: - raise UserError( - f"Unsupported media type: {media_type}", - get_usage_message(FILE_EXTENSIONS), - ) - return cast(ImageMediaType, media_type) +def _create_text_block(text: str) -> TextBlockParam: + return TextBlockParam(text=text, type="text") -def _create_image_block( - media_type: ImageMediaType, data: str -) -> ImageBlockParam: +def _create_image_block(resource: Resource) -> ImageBlockParam: return ImageBlockParam( source=Source( - data=data, - media_type=media_type, + data=resource.data_base64, + media_type=cast(ImageMediaType, resource.type), type="base64", ), type="image", ) -async def _download_data(url: str, file_storage: Optional[FileStorage]) -> str: - if not file_storage: - return await download_file_as_base64(url) - - return await file_storage.download_file_as_base64(url) - - -async def _to_claude_image( - attachment: Attachment, file_storage: Optional[FileStorage] +async def _collect_image_block( + file_storage: FileStorage | None, dial_resource: DialResource ) -> ImageBlockParam: - if attachment.data: - if not attachment.type: - raise ValidationError( - "Attachment type is required for provided data" - ) - return _create_image_block( - _validate_media_type(attachment.type), attachment.data + try: + resource = await dial_resource.download(file_storage) + except UnsupportedContentType as e: + raise UserError( + f"Unsupported media type: {e.type}", + get_usage_message(FILE_EXTENSIONS), ) - if attachment.url: - media_type = attachment.type or mimetypes.guess_type(attachment.url)[0] - if not media_type: - raise ValidationError( - f"Cannot guess attachment type for {attachment.url}" - ) - - data = await _download_data(attachment.url, file_storage) - return _create_image_block(_validate_media_type(media_type), data) - - raise ValidationError("Attachment data or URL is required") + return _create_image_block(resource) async def _to_claude_message( + file_storage: FileStorage | None, message: AIRegularMessage | HumanRegularMessage, - file_storage: Optional[FileStorage], ) -> List[TextBlockParam | ImageBlockParam]: - content: List[TextBlockParam | ImageBlockParam] = [] + ret: List[TextBlockParam | ImageBlockParam] = [] - if message.custom_content: - for attachment in message.custom_content.attachments or []: - content.append(await _to_claude_image(attachment, file_storage)) + for attachment in message.attachments: + dial_resource = AttachmentResource( + attachment=attachment, + entity_name="image attachment", + supported_types=IMAGE_MEDIA_TYPES, + ) + ret.append(await _collect_image_block(file_storage, dial_resource)) + + content = message.content + + match content: + case str(): + ret.append(_create_text_block(content)) + case list(): + for part in content: + match part: + case MessageContentTextPart(text=text): + ret.append(_create_text_block(text)) + case MessageContentImagePart(image_url=image_url): + dial_resource = URLResource( + url=image_url.url, + entity_name="image url", + supported_types=IMAGE_MEDIA_TYPES, + ) + ret.append( + await _collect_image_block( + file_storage, dial_resource + ) + ) + case _: + assert_never(part) + case _: + assert_never(content) - content.append(TextBlockParam(text=message.content, type="text")) - return content + return ret def _to_claude_tool_call(call: ToolCall) -> ToolUseBlockParam: @@ -130,7 +138,7 @@ def _to_claude_tool_result( return ToolResultBlockParam( tool_use_id=message.id, type="tool_result", - content=[TextBlockParam(text=message.content, type="text")], + content=[_create_text_block(message.content)], ) @@ -143,7 +151,7 @@ async def to_claude_messages( system_prompt: str | None = None if isinstance(messages[0], SystemMessage): - system_prompt = messages[0].content + system_prompt = messages[0].text_content messages = messages[1:] claude_messages: List[MessageParam] = [] @@ -153,14 +161,14 @@ async def to_claude_messages( claude_messages.append( MessageParam( role="user", - content=await _to_claude_message(message, file_storage), + content=await _to_claude_message(file_storage, message), ) ) case AIRegularMessage(): claude_messages.append( MessageParam( role="assistant", - content=await _to_claude_message(message, file_storage), + content=await _to_claude_message(file_storage, message), ) ) case AIToolCallMessage(): @@ -168,9 +176,7 @@ async def to_claude_messages( _to_claude_tool_call(call) for call in message.calls ] if message.content is not None: - content.insert( - 0, TextBlockParam(text=message.content, type="text") - ) + content.insert(0, _create_text_block(message.content)) claude_messages.append( MessageParam( @@ -220,7 +226,7 @@ def to_dial_finish_reason( "A model has called a tool, but no tools were given to the model in the first place." ) case _: - raise Exception(f"Unknown {tools_mode} during tool use!") + assert_never(tools_mode) case _: assert_never(finish_reason) diff --git a/aidial_adapter_bedrock/llm/model/claude/v3/tools.py b/aidial_adapter_bedrock/llm/model/claude/v3/tools.py index db048a0..7b8d18b 100644 --- a/aidial_adapter_bedrock/llm/model/claude/v3/tools.py +++ b/aidial_adapter_bedrock/llm/model/claude/v3/tools.py @@ -8,11 +8,13 @@ from aidial_adapter_bedrock.llm.errors import ValidationError from aidial_adapter_bedrock.llm.message import ( AIFunctionCallMessage, + AIRegularMessage, AIToolCallMessage, BaseMessage, HumanFunctionResultMessage, HumanRegularMessage, HumanToolResultMessage, + SystemMessage, ToolMessage, ) from aidial_adapter_bedrock.llm.tools.tools_config import ToolsMode @@ -44,7 +46,7 @@ def process_tools_block( "A model has called a tool, but no tools were given to the model in the first place." ) case _: - raise Exception(f"Unknown {tools_mode} during tool use!") + assert_never(tools_mode) def process_with_tools( @@ -64,8 +66,8 @@ def process_with_tools( ) return message elif tools_mode == ToolsMode.TOOLS: - if isinstance(message, HumanFunctionResultMessage) or isinstance( - message, AIFunctionCallMessage + if isinstance( + message, (HumanFunctionResultMessage, AIFunctionCallMessage) ): raise ValidationError( "You cannot use function messages with tools config." @@ -73,7 +75,7 @@ def process_with_tools( return message elif tools_mode == ToolsMode.FUNCTIONS: match message: - case HumanRegularMessage(): + case SystemMessage() | HumanRegularMessage() | AIRegularMessage(): return message case HumanToolResultMessage() | AIToolCallMessage(): raise ValidationError( @@ -96,7 +98,7 @@ def process_with_tools( id=message.name, content=message.content ) case _: - raise ValueError(f"Unknown message type {type(message)}") + assert_never(message) else: assert_never(tools_mode) diff --git a/aidial_adapter_bedrock/llm/model/llama/v2.py b/aidial_adapter_bedrock/llm/model/llama/v2.py index 806c1a6..66f5e55 100644 --- a/aidial_adapter_bedrock/llm/model/llama/v2.py +++ b/aidial_adapter_bedrock/llm/model/llama/v2.py @@ -47,7 +47,7 @@ def prepend_to_first_human_message(self, text: str) -> None: def validate_chat(messages: List[BaseMessage]) -> Dialogue: system: Optional[str] = None if messages and isinstance(messages[0], SystemMessage): - system = messages[0].content + system = messages[0].text_content if system.strip() == "": system = None messages = messages[1:] @@ -66,7 +66,7 @@ def validate_chat(messages: List[BaseMessage]) -> Dialogue: ) turns = [ - (human.content, assistant.content) + (human.text_content, assistant.text_content) for human, assistant in zip(human, ai) ] @@ -78,7 +78,7 @@ def validate_chat(messages: List[BaseMessage]) -> Dialogue: return Dialogue( system=system, turns=turns, - human=last_query.content, + human=last_query.text_content, ) diff --git a/aidial_adapter_bedrock/llm/model/llama/v3.py b/aidial_adapter_bedrock/llm/model/llama/v3.py index 52fa1bd..6ec4457 100644 --- a/aidial_adapter_bedrock/llm/model/llama/v3.py +++ b/aidial_adapter_bedrock/llm/model/llama/v3.py @@ -40,7 +40,7 @@ def encode_header(message: BaseMessage) -> str: def encode_message(message: BaseMessage) -> str: ret = encode_header(message) - ret += message.content.strip() + ret += message.text_content.strip() ret += "<|eot_id|>" return ret diff --git a/aidial_adapter_bedrock/llm/model/stability.py b/aidial_adapter_bedrock/llm/model/stability.py index 714b1c0..0e20aff 100644 --- a/aidial_adapter_bedrock/llm/model/stability.py +++ b/aidial_adapter_bedrock/llm/model/stability.py @@ -121,7 +121,7 @@ async def truncate_and_linearize_messages( raise ValidationError("List of messages must not be empty") return TextCompletionPrompt( - text=messages[-1].content, + text=messages[-1].text_content, stop_sequences=[], discarded_messages=list(range(len(messages) - 1)), ) diff --git a/aidial_adapter_bedrock/llm/tools/claude_emulator.py b/aidial_adapter_bedrock/llm/tools/claude_emulator.py index 83abf49..cb4d6dc 100644 --- a/aidial_adapter_bedrock/llm/tools/claude_emulator.py +++ b/aidial_adapter_bedrock/llm/tools/claude_emulator.py @@ -81,7 +81,7 @@ def add_tool_declarations( # Concat with the user system message if len(messages) > 0 and isinstance(messages[0], SystemMessage): - system_message += "\n" + messages[0].content + system_message += "\n" + messages[0].text_content messages = messages[1:] return [SystemMessage(content=system_message), *messages] diff --git a/aidial_adapter_bedrock/utils/resource.py b/aidial_adapter_bedrock/utils/resource.py new file mode 100644 index 0000000..caef5a2 --- /dev/null +++ b/aidial_adapter_bedrock/utils/resource.py @@ -0,0 +1,54 @@ +import base64 +import re +from typing import Optional + +from pydantic import BaseModel + + +class Resource(BaseModel): + type: str + data: bytes + + @classmethod + def from_base64(cls, type: str, data_base64: str) -> "Resource": + try: + data = base64.b64decode(data_base64, validate=True) + except Exception: + raise ValueError("Invalid base64 data") + + return cls(type=type, data=data) + + @classmethod + def from_data_url(cls, data_url: str) -> Optional["Resource"]: + """ + Parsing a resource encoded as a data URL. + See https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs for reference. + """ + + type = cls.parse_data_url_content_type(data_url) + if type is None: + return None + + data_base64 = data_url.removeprefix(cls._to_data_url_prefix(type)) + + return cls.from_base64(type, data_base64) + + @property + def data_base64(self) -> str: + return base64.b64encode(self.data).decode() + + def to_data_url(self) -> str: + return f"{self._to_data_url_prefix(self.type)}{self.data_base64}" + + @staticmethod + def parse_data_url_content_type(data_url: str) -> Optional[str]: + pattern = r"^data:([^;]+);base64," + match = re.match(pattern, data_url) + return None if match is None else match.group(1) + + @staticmethod + def _to_data_url_prefix(content_type: str) -> str: + return f"data:{content_type};base64," + + def __str__(self) -> str: + return self.to_data_url()[:100] + "..." diff --git a/aidial_adapter_bedrock/utils/text.py b/aidial_adapter_bedrock/utils/text.py new file mode 100644 index 0000000..ad665e3 --- /dev/null +++ b/aidial_adapter_bedrock/utils/text.py @@ -0,0 +1,4 @@ +def truncate_string(s: str, n: int) -> str: + if len(s) <= n: + return s + return s[:n] + "..." diff --git a/poetry.lock b/poetry.lock index 693fa47..926865b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,13 +2,13 @@ [[package]] name = "aidial-sdk" -version = "0.13.0" +version = "0.14.0" description = "Framework to create applications and model adapters for AI DIAL" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "aidial_sdk-0.13.0-py3-none-any.whl", hash = "sha256:35784f12367e43f4540d67bab7b18315832e313517e02e969068d7ff2de3d69e"}, - {file = "aidial_sdk-0.13.0.tar.gz", hash = "sha256:c895c22d95d1c1954e170ebda3f5010e80cd47ed8b7225d375d1da01f67962e5"}, + {file = "aidial_sdk-0.14.0-py3-none-any.whl", hash = "sha256:b3974855104c589033cea2581c8e02b650205d094673e7dadb6f6101ad8e6f38"}, + {file = "aidial_sdk-0.14.0.tar.gz", hash = "sha256:3442934a35b3bd0c7495f79717fb36ff949366ab0ec7406595ad2d5c8d25864c"}, ] [package.dependencies] @@ -2519,4 +2519,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it [metadata] lock-version = "2.0" python-versions = "^3.11,<4.0" -content-hash = "86f7b681e605415dafcbf3a68a18388e294a94019089aaaa752eaea6f79ea6c1" +content-hash = "f0e361bb3f94b328e91a84354b299679d5ff9a269f01d368cc85f485b4a808c2" diff --git a/pyproject.toml b/pyproject.toml index 6e72dde..68b51ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ repository = "https://github.com/epam/ai-dial-adapter-bedrock/" python = "^3.11,<4.0" boto3 = "1.28.57" botocore = "1.31.57" -aidial-sdk = {version = "0.13.0", extras = ["telemetry"]} +aidial-sdk = {version = "0.14.0", extras = ["telemetry"]} anthropic = {version = "0.28.1", extras = ["bedrock"]} fastapi = "0.109.2" openai = "1.13.3" diff --git a/tests/integration_tests/test_chat_completion.py b/tests/integration_tests/test_chat_completion.py index 7288d32..758d515 100644 --- a/tests/integration_tests/test_chat_completion.py +++ b/tests/integration_tests/test_chat_completion.py @@ -15,6 +15,7 @@ from pydantic import BaseModel from aidial_adapter_bedrock.deployments import ChatCompletionDeployment +from aidial_adapter_bedrock.utils.resource import Resource from tests.conftest import TEST_SERVER_URL from tests.utils.json import match_objects from tests.utils.openai import ( @@ -33,6 +34,9 @@ tool_request, tool_response, user, + user_with_attachment_data, + user_with_attachment_url, + user_with_image_url, ) @@ -116,10 +120,8 @@ def get_id(self): ChatCompletionDeployment.ANTHROPIC_CLAUDE_V2_1, ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_SONNET, ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_SONNET_US, - ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_SONNET_EU, ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET, ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_US, - ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_EU, ChatCompletionDeployment.META_LLAMA2_13B_CHAT_V1, ChatCompletionDeployment.META_LLAMA2_70B_CHAT_V1, ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1, @@ -168,6 +170,16 @@ def is_claude3(deployment: ChatCompletionDeployment) -> bool: ] +def is_vision_model(deployment: ChatCompletionDeployment) -> bool: + return is_claude3(deployment) + + +blue_pic = Resource.from_base64( + type="image/png", + data_base64="iVBORw0KGgoAAAANSUhEUgAAAAMAAAADCAIAAADZSiLoAAAAF0lEQVR4nGNkYPjPwMDAwMDAxAADCBYAG10BBdmz9y8AAAAASUVORK5CYII=", +) + + def get_test_cases( deployment: ChatCompletionDeployment, streaming: bool ) -> List[TestCase]: @@ -287,6 +299,22 @@ def test_case( ), ) + if is_vision_model(deployment): + content = "describe the image" + for idx, user_message in enumerate( + [ + user_with_attachment_data(content, blue_pic), + user_with_attachment_url(content, blue_pic), + user_with_image_url(content, blue_pic), + ] + ): + test_case( + name=f"describe image {idx}", + max_tokens=100, + messages=[sys("be a helpful assistant"), user_message], # type: ignore + expected=lambda s: "blue" in s.content.lower(), + ) + test_case( name="pinocchio in one token", max_tokens=1, @@ -326,6 +354,12 @@ def test_case( if supports_tools(deployment): query = "What's the temperature in Glasgow in celsius?" + chat_history = [ + sys("act as a helpful assistant"), + user("2+3=?"), + ai("5"), + user(query), + ] function_args_checker = { "location": lambda s: "glasgow" in s.lower(), @@ -339,7 +373,7 @@ def test_case( # Functions test_case( name="weather function", - messages=[user(query)], + messages=chat_history, functions=[GET_WEATHER_FUNCTION], expected=lambda s: is_valid_function_call( s.function_call, name, function_args_checker @@ -351,7 +385,7 @@ def test_case( test_case( name="weather function followup", - messages=[user(query), function_req, function_resp], + messages=[*chat_history, function_req, function_resp], functions=[GET_WEATHER_FUNCTION], expected=lambda s: "15" in s.content.lower(), ) @@ -360,7 +394,7 @@ def test_case( tool_call_id = f"{name}_1" test_case( name="weather tool", - messages=[user(query)], + messages=chat_history, tools=[GET_WEATHER_TOOL], expected=lambda s: is_valid_tool_calls( s.tool_calls, tool_call_id, name, function_args_checker @@ -372,7 +406,7 @@ def test_case( test_case( name="weather tool followup", - messages=[user(query), tool_req, tool_resp], + messages=[*chat_history, tool_req, tool_resp], tools=[GET_WEATHER_TOOL], expected=lambda s: "15" in s.content.lower(), ) diff --git a/tests/unit_tests/chat_emulation/test_llama2_chat.py b/tests/unit_tests/chat_emulation/test_llama2_chat.py index 0b2e2d5..a0010a8 100644 --- a/tests/unit_tests/chat_emulation/test_llama2_chat.py +++ b/tests/unit_tests/chat_emulation/test_llama2_chat.py @@ -23,7 +23,7 @@ async def truncate_prompt_by_words( model_limit: Optional[int] = None, ) -> DiscardedMessages | TruncatePromptError: async def _tokenize_by_words(messages: List[BaseMessage]) -> int: - return sum(len(msg.content.split()) for msg in messages) + return sum(len(msg.text_content.split()) for msg in messages) return await compute_discarded_messages( messages=messages, diff --git a/tests/unit_tests/test_truncate_prompt.py b/tests/unit_tests/test_truncate_prompt.py index 907271e..28221f9 100644 --- a/tests/unit_tests/test_truncate_prompt.py +++ b/tests/unit_tests/test_truncate_prompt.py @@ -21,7 +21,7 @@ async def truncate_prompt_by_words( model_limit: Optional[int] = None, ) -> DiscardedMessages | TruncatePromptError: async def _tokenize_by_words(messages: List[BaseMessage]) -> int: - return sum(len(msg.content.split()) for msg in messages) + return sum(len(msg.text_content.split()) for msg in messages) return await compute_discarded_messages( messages=messages, diff --git a/tests/utils/messages.py b/tests/utils/messages.py index 2281396..d204ba5 100644 --- a/tests/utils/messages.py +++ b/tests/utils/messages.py @@ -1,13 +1,13 @@ from typing import List -from aidial_sdk.chat_completion import Attachment, CustomContent, Message +from aidial_sdk.chat_completion import Attachment, CustomContent +from aidial_sdk.chat_completion import Message as DialMessage from aidial_adapter_bedrock.llm.message import ( AIRegularMessage, - BaseMessage, HumanRegularMessage, + MessageABC, SystemMessage, - ToolMessage, ) @@ -30,5 +30,5 @@ def user_with_image(content: str, image_base64: str) -> HumanRegularMessage: return HumanRegularMessage(content=content, custom_content=custom_content) -def to_sdk_messages(messages: List[BaseMessage | ToolMessage]) -> List[Message]: +def to_sdk_messages(messages: List[MessageABC]) -> List[DialMessage]: return [msg.to_message() for msg in messages] diff --git a/tests/utils/openai.py b/tests/utils/openai.py index f962fd8..dfec0c7 100644 --- a/tests/utils/openai.py +++ b/tests/utils/openai.py @@ -18,6 +18,9 @@ ChatCompletionToolParam, ChatCompletionUserMessageParam, ) +from openai.types.chat.chat_completion_content_part_param import ( + ChatCompletionContentPartParam, +) from openai.types.chat.chat_completion_message import ( ChatCompletionMessage, FunctionCall, @@ -29,6 +32,7 @@ from openai.types.shared_params.function_definition import FunctionDefinition from pydantic import BaseModel +from aidial_adapter_bedrock.utils.resource import Resource from tests.conftest import DEFAULT_API_VERSION @@ -52,10 +56,58 @@ def ai_tools( return {"role": "assistant", "tool_calls": tool_calls} -def user(content: str) -> ChatCompletionUserMessageParam: +def user( + content: str | List[ChatCompletionContentPartParam], +) -> ChatCompletionUserMessageParam: return {"role": "user", "content": content} +def user_with_attachment_data( + content: str, resource: Resource +) -> ChatCompletionUserMessageParam: + return { + "role": "user", + "content": content, + "custom_content": { # type: ignore + "attachments": [ + {"type": resource.type, "data": resource.data_base64} + ] + }, + } + + +def user_with_attachment_url( + content: str, resource: Resource +) -> ChatCompletionUserMessageParam: + return { + "role": "user", + "content": content, + "custom_content": { # type: ignore + "attachments": [ + { + "type": resource.type, + "url": resource.to_data_url(), + } + ] + }, + } + + +def user_with_image_url( + content: str, image: Resource +) -> ChatCompletionUserMessageParam: + return { + "role": "user", + "content": [ + {"type": "text", "text": content}, + { + "type": "image_url", + "image_url": {"url": image.to_data_url()}, + }, + ], + } + + def function_request(name: str, args: Any) -> ToolFunction: return {"name": name, "arguments": json.dumps(args)}