diff --git a/README.md b/README.md index 46653802..8a1de783 100644 --- a/README.md +++ b/README.md @@ -20,13 +20,15 @@ Note that a model supports `/truncate_prompt` endpoint if and only if it support |Anthropic|Claude 2.1|anthropic.claude-v2:1|text-to-text|✅|✅|✅| |Anthropic|Claude 2|anthropic.claude-v2|text-to-text|✅|✅|❌| |Anthropic|Claude Instant 1.2|anthropic.claude-instant-v1|text-to-text|🟡|🟡|❌| -|Meta|Llama 3.1 405B Instruct|meta.llama3-1-405b-instruct-v1:0|text-to-text|🟡|🟡|❌| -|Meta|Llama 3.1 70B Instruct|meta.llama3-1-70b-instruct-v1:0|text-to-text|🟡|🟡|❌| +|Meta|Llama 3.2 90B Instruct|us.meta.llama3-2-90b-instruct-v1:0|text-to-text, image-to-text|🟡|🟡|✅| +|Meta|Llama 3.2 11B Instruct|us.meta.llama3-2-11b-instruct-v1:0|text-to-text, image-to-text|🟡|🟡|❌| +|Meta|Llama 3.2 3B Instruct|us.meta.llama3-2-3b-instruct-v1:0|text-to-text|🟡|🟡|❌| +|Meta|Llama 3.2 1B Instruct|us.meta.llama3-2-1b-instruct-v1:0|text-to-text|🟡|🟡|❌| +|Meta|Llama 3.1 405B Instruct|meta.llama3-1-405b-instruct-v1:0|text-to-text|🟡|🟡|✅| +|Meta|Llama 3.1 70B Instruct|meta.llama3-1-70b-instruct-v1:0|text-to-text|🟡|🟡|✅| |Meta|Llama 3.1 8B Instruct|meta.llama3-1-8b-instruct-v1:0|text-to-text|🟡|🟡|❌| |Meta|Llama 3 Chat 70B Instruct|meta.llama3-70b-instruct-v1:0|text-to-text|🟡|🟡|❌| |Meta|Llama 3 Chat 8B Instruct|meta.llama3-8b-instruct-v1:0|text-to-text|🟡|🟡|❌| -|Meta|Llama 2 Chat 70B|meta.llama2-70b-chat-v1|text-to-text|🟡|🟡|❌| -|Meta|Llama 2 Chat 13B|meta.llama2-13b-chat-v1|text-to-text|🟡|🟡|❌| |Stability AI|SDXL 1.0|stability.stable-diffusion-xl-v1|text-to-image|❌|🟡|❌| |Stability AI|SD3 Large 1.0|stability.sd3-large-v1:0|text-to-image / image-to-image|❌|🟡|❌| |Stability AI|Stable Image Ultra 1.0|stability.stable-image-ultra-v1:0|text-to-image|❌|🟡|❌| diff --git a/aidial_adapter_bedrock/bedrock.py b/aidial_adapter_bedrock/bedrock.py index 21fe0fa3..dd919a3a 100644 --- a/aidial_adapter_bedrock/bedrock.py +++ b/aidial_adapter_bedrock/bedrock.py @@ -1,7 +1,7 @@ import json from abc import ABC from logging import DEBUG -from typing import Any, AsyncIterator, Mapping, Optional, Tuple +from typing import Any, AsyncIterator, Mapping, Optional, Tuple, Unpack import boto3 from botocore.eventstream import EventStream @@ -10,6 +10,7 @@ from aidial_adapter_bedrock.aws_client_config import AWSClientConfig from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage +from aidial_adapter_bedrock.llm.converse.types import ConverseRequest from aidial_adapter_bedrock.utils.concurrency import ( make_async, to_async_iterator, @@ -36,6 +37,23 @@ async def acreate(cls, aws_client_config: AWSClientConfig) -> "Bedrock": ) return cls(client) + async def aconverse_non_streaming( + self, model: str, **params: Unpack[ConverseRequest] + ): + response = await make_async( + lambda: self.client.converse(modelId=model, **params) + ) + return response + + async def aconverse_streaming( + self, model: str, **params: Unpack[ConverseRequest] + ): + response = await make_async( + lambda: self.client.converse_stream(modelId=model, **params) + ) + + return to_async_iterator(iter(response["stream"])) + def _create_invoke_params(self, model: str, body: dict) -> dict: return { "modelId": model, diff --git a/aidial_adapter_bedrock/deployments.py b/aidial_adapter_bedrock/deployments.py index e0e96a87..8758708a 100644 --- a/aidial_adapter_bedrock/deployments.py +++ b/aidial_adapter_bedrock/deployments.py @@ -43,13 +43,15 @@ class ChatCompletionDeployment(str, Enum): STABILITY_STABLE_DIFFUSION_3_LARGE_V1 = "stability.sd3-large-v1:0" STABILITY_STABLE_IMAGE_ULTRA_V1 = 
"stability.stable-image-ultra-v1:0" - META_LLAMA2_13B_CHAT_V1 = "meta.llama2-13b-chat-v1" - META_LLAMA2_70B_CHAT_V1 = "meta.llama2-70b-chat-v1" META_LLAMA3_8B_INSTRUCT_V1 = "meta.llama3-8b-instruct-v1:0" META_LLAMA3_70B_INSTRUCT_V1 = "meta.llama3-70b-instruct-v1:0" - META_LLAMA3_1_405B_INSTRUCT_V1 = "meta.llama3-1-405b-instruct-v1:0" - META_LLAMA3_1_70B_INSTRUCT_V1 = "meta.llama3-1-70b-instruct-v1:0" META_LLAMA3_1_8B_INSTRUCT_V1 = "meta.llama3-1-8b-instruct-v1:0" + META_LLAMA3_1_70B_INSTRUCT_V1 = "meta.llama3-1-70b-instruct-v1:0" + META_LLAMA3_1_405B_INSTRUCT_V1 = "meta.llama3-1-405b-instruct-v1:0" + META_LLAMA3_2_1B_INSTRUCT_V1 = "us.meta.llama3-2-1b-instruct-v1:0" + META_LLAMA3_2_3B_INSTRUCT_V1 = "us.meta.llama3-2-3b-instruct-v1:0" + META_LLAMA3_2_11B_INSTRUCT_V1 = "us.meta.llama3-2-11b-instruct-v1:0" + META_LLAMA3_2_90B_INSTRUCT_V1 = "us.meta.llama3-2-90b-instruct-v1:0" COHERE_COMMAND_TEXT_V14 = "cohere.command-text-v14" COHERE_COMMAND_LIGHT_TEXT_V14 = "cohere.command-light-text-v14" diff --git a/aidial_adapter_bedrock/llm/converse/__init__.py b/aidial_adapter_bedrock/llm/converse/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aidial_adapter_bedrock/llm/converse/adapter.py b/aidial_adapter_bedrock/llm/converse/adapter.py new file mode 100644 index 00000000..3da26bb9 --- /dev/null +++ b/aidial_adapter_bedrock/llm/converse/adapter.py @@ -0,0 +1,180 @@ +from typing import Any, Awaitable, Callable, List, Tuple + +from aidial_sdk.chat_completion import Message as DialMessage + +from aidial_adapter_bedrock.bedrock import Bedrock +from aidial_adapter_bedrock.dial_api.request import ModelParameters +from aidial_adapter_bedrock.dial_api.storage import FileStorage +from aidial_adapter_bedrock.llm.chat_model import ( + ChatCompletionAdapter, + keep_last, + turn_based_partitioner, +) +from aidial_adapter_bedrock.llm.consumer import Consumer +from aidial_adapter_bedrock.llm.converse.input import ( + extract_converse_system_prompt, + to_converse_messages, + to_converse_tools, +) +from aidial_adapter_bedrock.llm.converse.output import ( + process_non_streaming, + process_streaming, +) +from aidial_adapter_bedrock.llm.converse.types import ( + ConverseDeployment, + ConverseMessage, + ConverseRequestWrapper, + ConverseTools, + InferenceConfig, +) +from aidial_adapter_bedrock.llm.errors import ValidationError +from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string +from aidial_adapter_bedrock.llm.truncate_prompt import ( + DiscardedMessages, + truncate_prompt, +) +from aidial_adapter_bedrock.utils.json import remove_nones +from aidial_adapter_bedrock.utils.list import omit_by_indices +from aidial_adapter_bedrock.utils.list_projection import ListProjection + +ConverseMessages = List[Tuple[ConverseMessage, Any]] + + +class ConverseAdapter(ChatCompletionAdapter): + deployment: str + bedrock: Bedrock + storage: FileStorage | None + + tokenize_text: Callable[[str], int] = default_tokenize_string + input_tokenizer_factory: Callable[ + [ConverseDeployment, ConverseRequestWrapper], + Callable[[ConverseMessages], Awaitable[int]], + ] + support_tools: bool + partitioner: Callable[[ConverseMessages], List[int]] = ( + turn_based_partitioner + ) + + async def _discard_messages( + self, params: ConverseRequestWrapper, max_prompt_tokens: int | None + ) -> Tuple[DiscardedMessages | None, ConverseRequestWrapper]: + if max_prompt_tokens is None: + return None, params + + discarded_messages, messages = await truncate_prompt( + messages=params.messages.list, + 
tokenizer=self.input_tokenizer_factory(self.deployment, params), + keep_message=keep_last, + partitioner=self.partitioner, + model_limit=None, + user_limit=max_prompt_tokens, + ) + + return list( + params.messages.to_original_indices(discarded_messages) + ), ConverseRequestWrapper( + messages=ListProjection( + omit_by_indices(messages, discarded_messages) + ), + system=params.system, + inferenceConfig=params.inferenceConfig, + toolConfig=params.toolConfig, + ) + + async def count_prompt_tokens( + self, params: ModelParameters, messages: List[DialMessage] + ) -> int: + converse_params = await self.construct_converse_params(messages, params) + return await self.input_tokenizer_factory( + self.deployment, converse_params + )(converse_params.messages.list) + + async def count_completion_tokens(self, string: str) -> int: + return self.tokenize_text(string) + + async def compute_discarded_messages( + self, params: ModelParameters, messages: List[DialMessage] + ) -> DiscardedMessages | None: + converse_params = await self.construct_converse_params(messages, params) + discarded_messages, _ = await self._discard_messages( + converse_params, params.max_prompt_tokens + ) + return discarded_messages + + def get_tool_config(self, params: ModelParameters) -> ConverseTools | None: + if params.tool_config and not self.support_tools: + raise ValidationError("Tools are not supported") + return ( + to_converse_tools(params.tool_config) + if params.tool_config + else None + ) + + async def construct_converse_params( + self, + messages: List[DialMessage], + params: ModelParameters, + ) -> ConverseRequestWrapper: + system_prompt_extraction = extract_converse_system_prompt(messages) + converse_messages = await to_converse_messages( + system_prompt_extraction.non_system_messages, + self.storage, + start_offset=system_prompt_extraction.system_message_count, + ) + system_message = system_prompt_extraction.system_prompt + if not converse_messages.list: + raise ValidationError("List of messages must not be empty") + + return ConverseRequestWrapper( + system=[system_message] if system_message else None, + messages=converse_messages, + inferenceConfig=InferenceConfig( + **remove_nones( + { + "temperature": params.temperature, + "topP": params.top_p, + "maxTokens": params.max_tokens, + "stopSequences": params.stop, + } + ) + ), + toolConfig=self.get_tool_config(params), + ) + + def is_stream(self, params: ModelParameters) -> bool: + return params.stream + + async def chat( + self, + consumer: Consumer, + params: ModelParameters, + messages: List[DialMessage], + ) -> None: + + converse_params = await self.construct_converse_params(messages, params) + discarded_messages, converse_params = await self._discard_messages( + converse_params, params.max_prompt_tokens + ) + if not converse_params.messages.raw_list: + raise ValidationError("No messages left after truncation") + + consumer.set_discarded_messages(discarded_messages) + + if self.is_stream(params): + await process_streaming( + params=params, + stream=( + await self.bedrock.aconverse_streaming( + self.deployment, **converse_params.to_request() + ) + ), + consumer=consumer, + ) + else: + process_non_streaming( + params=params, + response=await self.bedrock.aconverse_non_streaming( + self.deployment, **converse_params.to_request() + ), + consumer=consumer, + ) diff --git a/aidial_adapter_bedrock/llm/converse/constants.py b/aidial_adapter_bedrock/llm/converse/constants.py new file mode 100644 index 00000000..562c92b6 --- /dev/null +++ 
b/aidial_adapter_bedrock/llm/converse/constants.py @@ -0,0 +1,12 @@ +from aidial_sdk.chat_completion import FinishReason as DialFinishReason + +from aidial_adapter_bedrock.llm.converse.types import ConverseStopReason + +CONVERSE_TO_DIAL_FINISH_REASON = { + ConverseStopReason.END_TURN: DialFinishReason.STOP, + ConverseStopReason.TOOL_USE: DialFinishReason.TOOL_CALLS, + ConverseStopReason.MAX_TOKENS: DialFinishReason.LENGTH, + ConverseStopReason.STOP_SEQUENCE: DialFinishReason.STOP, + ConverseStopReason.GUARDRAIL_INTERVENED: DialFinishReason.CONTENT_FILTER, + ConverseStopReason.CONTENT_FILTERED: DialFinishReason.CONTENT_FILTER, +} diff --git a/aidial_adapter_bedrock/llm/converse/input.py b/aidial_adapter_bedrock/llm/converse/input.py new file mode 100644 index 00000000..1df45a2d --- /dev/null +++ b/aidial_adapter_bedrock/llm/converse/input.py @@ -0,0 +1,325 @@ +import json +from dataclasses import dataclass +from typing import List, Set, Tuple, assert_never + +from aidial_sdk.chat_completion import FunctionCall as DialFunctionCall +from aidial_sdk.chat_completion import Message as DialMessage +from aidial_sdk.chat_completion import ( + MessageContentImagePart, + MessageContentTextPart, +) +from aidial_sdk.chat_completion import Role as DialRole +from aidial_sdk.chat_completion import ToolCall as DialToolCall +from aidial_sdk.exceptions import RuntimeServerError + +from aidial_adapter_bedrock.dial_api.request import ToolsConfig +from aidial_adapter_bedrock.dial_api.resource import ( + AttachmentResource, + URLResource, +) +from aidial_adapter_bedrock.dial_api.storage import FileStorage +from aidial_adapter_bedrock.llm.converse.types import ( + ConverseContentPart, + ConverseMessage, + ConverseRole, + ConverseTextPart, + ConverseToolResultPart, + ConverseTools, + ConverseToolSpec, + ConverseToolUsePart, +) +from aidial_adapter_bedrock.llm.errors import ValidationError +from aidial_adapter_bedrock.utils.list import group_by +from aidial_adapter_bedrock.utils.list_projection import ListProjection + + +def to_converse_role(role: DialRole) -> ConverseRole: + """ + Converse API accepts only 'user' and 'assistant' roles + """ + match role: + case DialRole.USER | DialRole.TOOL | DialRole.FUNCTION: + return ConverseRole.USER + case DialRole.ASSISTANT: + return ConverseRole.ASSISTANT + case DialRole.SYSTEM: + raise ValidationError("System messages are not allowed") + case _: + assert_never(role) + + +def to_converse_tools(tools_config: ToolsConfig) -> ConverseTools: + tools: list[ConverseToolSpec] = [] + for function in tools_config.functions: + tools.append( + { + "toolSpec": { + "name": function.name, + "description": function.description or "", + "inputSchema": { + "json": function.parameters + or {"type": "object", "properties": {}} + }, + } + } + ) + + return { + "tools": tools, + "toolChoice": ({"any": {}} if tools_config.required else {"auto": {}}), + } + + +def function_call_to_content_part( + dial_call: DialFunctionCall, +) -> ConverseToolUsePart: + return { + "toolUse": { + "toolUseId": dial_call.name, + "name": dial_call.name, + "input": json.loads(dial_call.arguments), + } + } + + +def tool_call_to_content_part( + dial_call: DialToolCall, +) -> ConverseToolUsePart: + return { + "toolUse": { + "toolUseId": dial_call.id, + "name": dial_call.function.name, + "input": json.loads(dial_call.function.arguments), + } + } + + +def function_result_to_content_part( + message: DialMessage, +) -> ConverseToolResultPart: + if message.role != DialRole.FUNCTION: + raise RuntimeServerError( + "Function 
result message is expected to have function role" + ) + if not message.name or not isinstance(message.content, str): + raise RuntimeServerError( + "Function result message is expected to have function name and plain text content" + ) + + return { + "toolResult": { + "toolUseId": message.name, + "content": [{"text": message.content}], + "status": "success", + } + } + + +def tool_result_to_content_part( + message: DialMessage, +) -> ConverseToolResultPart: + if message.role != DialRole.TOOL: + raise RuntimeServerError( + "Tool result message is expected to have tool role" + ) + if not message.tool_call_id or not isinstance(message.content, str): + raise RuntimeServerError( + "Tool result message is expected to have tool call id and plain text content" + ) + + try: + json_content = json.loads(message.content) + return { + "toolResult": { + "toolUseId": message.tool_call_id, + "content": [{"json": json_content}], + "status": "success", + } + } + except json.JSONDecodeError: + return { + "toolResult": { + "toolUseId": message.tool_call_id, + "content": [{"text": message.content}], + "status": "success", + } + } + + +def to_converse_image_type(type: str) -> str: + if type == "image/png": + return "png" + if type == "image/jpeg": + return "jpeg" + raise RuntimeServerError(f"Unsupported image type: {type}") + + +async def _get_converse_message_content( + message: DialMessage, + storage: FileStorage | None, + supported_image_types: list[str] | None = None, +) -> List[ConverseContentPart]: + + if message.role == DialRole.FUNCTION: + return [function_result_to_content_part(message)] + elif message.role == DialRole.TOOL: + return [tool_result_to_content_part(message)] + + content = [] + match message.content: + case str(): + content.append({"text": message.content}) + case list(): + for part in message.content: + match part: + case MessageContentTextPart(): + content.append({"text": part.text}) + case MessageContentImagePart(): + resource = await URLResource( + url=part.image_url.url, + supported_types=supported_image_types, + ).download(storage) + content.append( + { + "image": { + "format": to_converse_image_type( + resource.type + ), + "source": { + "bytes": resource.data, + }, + } + } + ) + case None: + pass + case _: + assert_never(message.content) + + if message.custom_content and message.custom_content.attachments: + for attachment in message.custom_content.attachments: + resource = await AttachmentResource(attachment=attachment).download( + storage + ) + content.append( + { + "image": { + "format": to_converse_image_type(resource.type), + "source": {"bytes": resource.data}, + } + } + ) + if message.function_call and message.tool_calls: + raise ValidationError( + "You cannot use both function call and tool calls in the same message" + ) + elif message.function_call: + content.append(function_call_to_content_part(message.function_call)) + elif message.tool_calls: + content.extend( + [ + tool_call_to_content_part(tool_call) + for tool_call in message.tool_calls + ] + ) + + return content + + +async def to_converse_message( + message: DialMessage, + storage: FileStorage | None, + supported_image_types: list[str] | None = None, +) -> ConverseMessage: + + return { + "role": to_converse_role(message.role), + "content": await _get_converse_message_content( + message, storage, supported_image_types + ), + } + + +@dataclass +class ExtractSystemPromptResult: + system_prompt: ConverseTextPart | None + system_message_count: int + non_system_messages: List[DialMessage] + + +def 
extract_converse_system_prompt( + messages: List[DialMessage], +) -> ExtractSystemPromptResult: + system_msgs = [] + found_non_system = False + system_messages_count = 0 + non_system_messages = [] + + for msg in messages: + if msg.role == DialRole.SYSTEM: + if found_non_system: + raise ValidationError( + "A system message can only follow another system message" + ) + system_messages_count += 1 + match msg.content: + case str(): + system_msgs.append(msg.content) + case list(): + for part in msg.content: + match part: + case MessageContentTextPart(): + system_msgs.append(part.text) + case MessageContentImagePart(): + raise ValidationError( + "System messages cannot contain images" + ) + case None: + pass + case _: + assert_never(msg.content) + else: + found_non_system = True + non_system_messages.append(msg) + combined = "\n\n".join(msg for msg in system_msgs if msg) + return ExtractSystemPromptResult( + system_prompt=ConverseTextPart(text=combined) if combined else None, + system_message_count=system_messages_count, + non_system_messages=non_system_messages, + ) + + +async def to_converse_messages( + messages: List[DialMessage], + storage: FileStorage | None, + # Offset for system messages at the beginning + start_offset: int = 0, +) -> ListProjection[ConverseMessage]: + def _merge( + a: Tuple[ConverseMessage, Set[int]], + b: Tuple[ConverseMessage, Set[int]], + ) -> Tuple[ConverseMessage, Set[int]]: + (msg1, set1), (msg2, set2) = a, b + + content1 = msg1["content"] + content2 = msg2["content"] + + return { + "role": msg1["role"], + "content": content1 + content2, + }, set1 | set2 + + converted = [ + (await to_converse_message(msg, storage), set([idx])) + for idx, msg in enumerate(messages, start=start_offset) + ] + + # Merge messages with the same roles to achieve an alternation of user-assistant roles. 
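+    # For example, consecutive [user, user, assistant] messages collapse into
+    # [user, assistant]: _merge concatenates the content parts and keeps the
+    # union of the original message indices, so discarded messages can still
+    # be reported against the caller's original message list.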
+ return ListProjection( + group_by( + lst=converted, + key=lambda msg: msg[0]["role"], + init=lambda msg: msg, + merge=_merge, + ) + ) diff --git a/aidial_adapter_bedrock/llm/converse/output.py b/aidial_adapter_bedrock/llm/converse/output.py new file mode 100644 index 00000000..83d07813 --- /dev/null +++ b/aidial_adapter_bedrock/llm/converse/output.py @@ -0,0 +1,147 @@ +import json +from typing import Any, AsyncIterator, Dict, assert_never + +from aidial_sdk.chat_completion import FinishReason as DialFinishReason +from aidial_sdk.chat_completion import FunctionCall as DialFunctionCall +from aidial_sdk.chat_completion import ToolCall as DialToolCall +from aidial_sdk.exceptions import RuntimeServerError + +from aidial_adapter_bedrock.dial_api.request import ModelParameters +from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage +from aidial_adapter_bedrock.llm.consumer import Consumer +from aidial_adapter_bedrock.llm.converse.constants import ( + CONVERSE_TO_DIAL_FINISH_REASON, +) +from aidial_adapter_bedrock.llm.converse.types import ConverseStopReason +from aidial_adapter_bedrock.llm.tools.tools_config import ToolsMode + + +def to_dial_finish_reason( + converse_stop_reason: ConverseStopReason, +) -> DialFinishReason: + if converse_stop_reason not in CONVERSE_TO_DIAL_FINISH_REASON.keys(): + raise RuntimeServerError( + f"Unsupported converse stop reason: {converse_stop_reason}" + ) + return CONVERSE_TO_DIAL_FINISH_REASON[converse_stop_reason] + + +async def process_streaming( + params: ModelParameters, + stream: AsyncIterator[Any], + consumer: Consumer, +) -> None: + current_tool_use = None + + async for event in stream: + if (content_block_start := event.get("contentBlockStart")) and ( + tool_use := content_block_start.get("start", {}).get("toolUse") + ): + if current_tool_use is not None: + raise ValueError("Tool use already started") + current_tool_use = {"input": ""} | tool_use + + elif content_block := event.get("contentBlockDelta"): + delta = content_block.get("delta", {}) + + if message := delta.get("text"): + consumer.append_content(message) + + if "toolUse" in delta: + if current_tool_use is None: + raise ValueError("Received tool delta before start block") + else: + current_tool_use["input"] += delta["toolUse"].get( + "input", "" + ) + + elif event.get("contentBlockStop"): + if current_tool_use: + + match params.tools_mode: + case ToolsMode.TOOLS: + consumer.create_function_tool_call( + tool_call=DialToolCall( + type="function", + id=current_tool_use["toolUseId"], + index=None, + function=DialFunctionCall( + name=current_tool_use["name"], + arguments=current_tool_use["input"], + ), + ) + ) + case ToolsMode.FUNCTIONS: + # ignoring multiple function calls in one response + if not consumer.has_function_call: + consumer.create_function_call( + function_call=DialFunctionCall( + name=current_tool_use["name"], + arguments=current_tool_use["input"], + ) + ) + case None: + raise RuntimeError( + "Tool use received without tools mode" + ) + case _: + assert_never(params.tools_mode) + current_tool_use = None + + elif (message_stop := event.get("messageStop")) and ( + stop_reason := message_stop.get("stopReason") + ): + consumer.close_content(to_dial_finish_reason(stop_reason)) + + +def process_non_streaming( + params: ModelParameters, + response: Dict[str, Any], + consumer: Consumer, +) -> None: + message = response["output"]["message"] + for content_block in message.get("content", []): + if "text" in content_block: + consumer.append_content(content_block["text"]) + if "toolUse" 
in content_block: + match params.tools_mode: + case ToolsMode.TOOLS: + consumer.create_function_tool_call( + tool_call=DialToolCall( + type="function", + id=content_block["toolUse"]["toolUseId"], + index=None, + function=DialFunctionCall( + name=content_block["toolUse"]["name"], + arguments=json.dumps( + content_block["toolUse"]["input"] + ), + ), + ) + ) + case ToolsMode.FUNCTIONS: + # ignoring multiple function calls in one response + if not consumer.has_function_call: + consumer.create_function_call( + function_call=DialFunctionCall( + name=content_block["toolUse"]["name"], + arguments=json.dumps( + content_block["toolUse"]["input"] + ), + ) + ) + case None: + raise RuntimeError("Tool use received without tools mode") + case _: + assert_never(params.tools_mode) + + if usage := response.get("usage"): + consumer.add_usage( + TokenUsage( + prompt_tokens=usage.get("inputTokens", 0), + completion_tokens=usage.get("outputTokens", 0), + ) + ) + + if stop_reason := response.get("stopReason"): + consumer.close_content(to_dial_finish_reason(stop_reason)) diff --git a/aidial_adapter_bedrock/llm/converse/types.py b/aidial_adapter_bedrock/llm/converse/types.py new file mode 100644 index 00000000..54b8e153 --- /dev/null +++ b/aidial_adapter_bedrock/llm/converse/types.py @@ -0,0 +1,137 @@ +""" +Types for Converse API: +https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html +""" + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Literal, Required, TypedDict, Union + +from aidial_adapter_bedrock.utils.json import remove_nones +from aidial_adapter_bedrock.utils.list_projection import ListProjection + + +class ConverseRole(str, Enum): + USER = "user" + ASSISTANT = "assistant" + + +class ConverseTextPart(TypedDict): + text: str + + +class ConverseJsonPart(TypedDict): + json: dict + + +class ConverseImageSource(TypedDict): + bytes: bytes + + +class ConverseImagePartConfig(TypedDict): + format: Literal["png", "jpeg", "gif", "webp"] + source: ConverseImageSource + + +class ConverseImagePart(TypedDict): + image: ConverseImagePartConfig + + +class ConverseToolUseConfig(TypedDict): + toolUseId: str + name: str + # {...}|[...]|123|123.4|'string'|True|None + input: Any + + +class ConverseToolUsePart(TypedDict): + toolUse: ConverseToolUseConfig + + +class ConverseToolResultConfig(TypedDict): + toolUseId: str + content: list[ConverseTextPart | ConverseJsonPart] + status: str + + +class ConverseToolResultPart(TypedDict): + toolResult: ConverseToolResultConfig + + +ConverseContentPart = Union[ + ConverseTextPart, + ConverseJsonPart, + ConverseImagePart, + ConverseToolUsePart, + ConverseToolResultPart, +] + + +class ConverseToolConfig(TypedDict): + name: str + description: str + inputSchema: dict + + +class ConverseToolSpec(TypedDict): + toolSpec: ConverseToolConfig + + +class ConverseTools(TypedDict): + tools: list[ConverseToolSpec] + toolChoice: dict + + +class ConverseToolUse(TypedDict): + toolUse: ConverseToolUseConfig + + +class ConverseStopReason(str, Enum): + END_TURN = "end_turn" + TOOL_USE = "tool_use" + MAX_TOKENS = "max_tokens" + STOP_SEQUENCE = "stop_sequence" + GUARDRAIL_INTERVENED = "guardrail_intervened" + CONTENT_FILTERED = "content_filtered" + + +class ConverseMessage(TypedDict): + role: ConverseRole + content: list[ConverseContentPart] + + +class InferenceConfig(TypedDict, total=False): + temperature: float + topP: float + maxTokens: int + stopSequences: list[str] + + +class ConverseRequest(TypedDict, 
total=False): + messages: Required[list[ConverseMessage]] + system: list[ConverseTextPart] + inferenceConfig: InferenceConfig + toolConfig: ConverseTools + + +@dataclass +class ConverseRequestWrapper: + messages: ListProjection[ConverseMessage] + system: list[ConverseTextPart] | None = None + inferenceConfig: InferenceConfig | None = None + toolConfig: ConverseTools | None = None + + def to_request(self) -> ConverseRequest: + return ConverseRequest( + messages=self.messages.raw_list, + **remove_nones( + { + "inferenceConfig": self.inferenceConfig, + "toolConfig": self.toolConfig, + "system": self.system, + } + ), + ) + + +ConverseDeployment = str diff --git a/aidial_adapter_bedrock/llm/model/adapter.py b/aidial_adapter_bedrock/llm/model/adapter.py index dcc39b7b..9ee6c50c 100644 --- a/aidial_adapter_bedrock/llm/model/adapter.py +++ b/aidial_adapter_bedrock/llm/model/adapter.py @@ -6,6 +6,7 @@ ChatCompletionDeployment, EmbeddingsDeployment, ) +from aidial_adapter_bedrock.dial_api.storage import create_file_storage from aidial_adapter_bedrock.embedding.amazon.titan_image import ( AmazonTitanImageEmbeddings, ) @@ -19,6 +20,7 @@ EmbeddingsAdapter, ) from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter +from aidial_adapter_bedrock.llm.converse.adapter import ConverseAdapter from aidial_adapter_bedrock.llm.model.ai21 import AI21Adapter from aidial_adapter_bedrock.llm.model.amazon import AmazonAdapter from aidial_adapter_bedrock.llm.model.claude.v1_v2.adapter import ( @@ -28,9 +30,12 @@ Adapter as Claude_V3, ) from aidial_adapter_bedrock.llm.model.cohere import CohereAdapter -from aidial_adapter_bedrock.llm.model.llama.v2 import llama2_config -from aidial_adapter_bedrock.llm.model.llama.v3 import llama3_config -from aidial_adapter_bedrock.llm.model.meta import MetaAdapter +from aidial_adapter_bedrock.llm.model.llama.v3 import ( + ConverseAdapterWithStreamingEmulation, +) +from aidial_adapter_bedrock.llm.model.llama.v3 import ( + input_tokenizer_factory as llama_tokenizer_factory, +) from aidial_adapter_bedrock.llm.model.stability.v1 import StabilityV1Adapter from aidial_adapter_bedrock.llm.model.stability.v2 import StabilityV2Adapter @@ -106,23 +111,31 @@ async def get_bedrock_adapter( await Bedrock.acreate(aws_client_config), model ) case ( - ChatCompletionDeployment.META_LLAMA2_13B_CHAT_V1 - | ChatCompletionDeployment.META_LLAMA2_70B_CHAT_V1 + ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1 + | ChatCompletionDeployment.META_LLAMA3_70B_INSTRUCT_V1 + | ChatCompletionDeployment.META_LLAMA3_1_8B_INSTRUCT_V1 + | ChatCompletionDeployment.META_LLAMA3_2_1B_INSTRUCT_V1 + | ChatCompletionDeployment.META_LLAMA3_2_3B_INSTRUCT_V1 ): - return MetaAdapter.create( - await Bedrock.acreate(aws_client_config), model, llama2_config + return ConverseAdapter( + deployment=model, + bedrock=await Bedrock.acreate(aws_client_config), + storage=create_file_storage(api_key), + input_tokenizer_factory=llama_tokenizer_factory, + support_tools=False, ) case ( - ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1 - | ChatCompletionDeployment.META_LLAMA3_70B_INSTRUCT_V1 + ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1 | ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1 - | ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1 - | ChatCompletionDeployment.META_LLAMA3_1_8B_INSTRUCT_V1 + | ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1 + | ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1 ): - return MetaAdapter.create( - await Bedrock.acreate(aws_client_config), - model, - 
llama3_config, + return ConverseAdapterWithStreamingEmulation( + deployment=model, + bedrock=await Bedrock.acreate(aws_client_config), + storage=create_file_storage(api_key), + input_tokenizer_factory=llama_tokenizer_factory, + support_tools=True, ) case ( ChatCompletionDeployment.COHERE_COMMAND_TEXT_V14 diff --git a/aidial_adapter_bedrock/llm/model/llama/conf.py b/aidial_adapter_bedrock/llm/model/llama/conf.py deleted file mode 100644 index 208792e4..00000000 --- a/aidial_adapter_bedrock/llm/model/llama/conf.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Callable, List - -from pydantic import BaseModel - -from aidial_adapter_bedrock.llm.chat_emulator import ChatEmulator -from aidial_adapter_bedrock.llm.message import BaseMessage - - -class LlamaConf(BaseModel): - chat_partitioner: Callable[[List[BaseMessage]], List[int]] - chat_emulator: ChatEmulator diff --git a/aidial_adapter_bedrock/llm/model/llama/v2.py b/aidial_adapter_bedrock/llm/model/llama/v2.py deleted file mode 100644 index 66f5e55c..00000000 --- a/aidial_adapter_bedrock/llm/model/llama/v2.py +++ /dev/null @@ -1,142 +0,0 @@ -""" -Turning a chat into a prompt for the Llama2 model. - -The reference for the algo is [this code snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362) in the original repository. - -See also the [tokenizer](https://github.com/huggingface/transformers/blob/c99f25476312521d4425335f970b198da42f832d/src/transformers/models/llama/tokenization_llama.py#L415) in the transformers package. -""" - -from typing import List, Optional, Tuple - -from pydantic import BaseModel - -from aidial_adapter_bedrock.llm.chat_emulator import ChatEmulator -from aidial_adapter_bedrock.llm.errors import ValidationError -from aidial_adapter_bedrock.llm.message import ( - AIRegularMessage, - BaseMessage, - HumanRegularMessage, - SystemMessage, -) -from aidial_adapter_bedrock.llm.model.llama.conf import LlamaConf - -B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "<>\n", "\n<>\n\n" -BOS = "" -EOS = "" - - -class Dialogue(BaseModel): - """Valid dialog structure for LLAMA2 model: - 1. optional system message, - 2. alternating user/assistant messages, - 3. 
last user query""" - - system: Optional[str] - turns: List[Tuple[str, str]] - human: str - - def prepend_to_first_human_message(self, text: str) -> None: - if self.turns: - human, ai = self.turns[0] - self.turns[0] = text + human, ai - else: - self.human = text + self.human - - -def validate_chat(messages: List[BaseMessage]) -> Dialogue: - system: Optional[str] = None - if messages and isinstance(messages[0], SystemMessage): - system = messages[0].text_content - if system.strip() == "": - system = None - messages = messages[1:] - - human = messages[::2] - ai = messages[1::2] - - is_valid_alternation = all( - isinstance(msg, HumanRegularMessage) for msg in human - ) and all(isinstance(msg, AIRegularMessage) for msg in ai) - - if not is_valid_alternation: - raise ValidationError( - "The model only supports initial optional system message and" - " follow-up alternating human/assistant messages" - ) - - turns = [ - (human.text_content, assistant.text_content) - for human, assistant in zip(human, ai) - ] - - if messages and isinstance(messages[-1], HumanRegularMessage): - last_query = messages[-1] - else: - raise ValidationError("The last message must be from user") - - return Dialogue( - system=system, - turns=turns, - human=last_query.text_content, - ) - - -def format_sequence(text: str, bos: bool, eos: bool) -> str: - if bos: - text = BOS + text - if eos: - text = text + EOS - return text - - -def create_chat_prompt(dialogue: Dialogue) -> str: - system = dialogue.system - if system is not None: - dialogue.prepend_to_first_human_message(B_SYS + system + E_SYS) - - ret: List[str] = [ - format_sequence( - f"{B_INST} {human.strip()} {E_INST} {ai.strip()} ", - bos=True, - eos=True, - ) - for human, ai in dialogue.turns - ] - - ret.append( - format_sequence( - f"{B_INST} {dialogue.human.strip()} {E_INST}", - bos=True, - eos=False, - ) - ) - - return "".join(ret) - - -class LlamaChatEmulator(ChatEmulator): - def display(self, messages: List[BaseMessage]) -> Tuple[str, List[str]]: - dialogue = validate_chat(messages) - return create_chat_prompt(dialogue), [] - - def get_ai_cue(self) -> Optional[str]: - return None - - -def llama2_chat_partitioner(messages: List[BaseMessage]) -> List[int]: - dialogue = validate_chat(messages) - - ret: List[int] = [] - if dialogue.system is not None: - ret.append(1) - ret.extend([2] * len(dialogue.turns)) - ret.append(1) - - return ret - - -llama2_config = LlamaConf( - chat_partitioner=llama2_chat_partitioner, - chat_emulator=LlamaChatEmulator(), -) diff --git a/aidial_adapter_bedrock/llm/model/llama/v3.py b/aidial_adapter_bedrock/llm/model/llama/v3.py index 6ec44575..182203ea 100644 --- a/aidial_adapter_bedrock/llm/model/llama/v3.py +++ b/aidial_adapter_bedrock/llm/model/llama/v3.py @@ -1,72 +1,40 @@ -""" -Turning a chat into a prompt for the Llama3 model. 
+import json -See as a reference: -https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py -""" - -from typing import List, Literal, Optional, Tuple, assert_never - -from aidial_adapter_bedrock.llm.chat_emulator import ChatEmulator -from aidial_adapter_bedrock.llm.message import ( - AIRegularMessage, - BaseMessage, - HumanRegularMessage, - SystemMessage, +from aidial_adapter_bedrock.dial_api.request import ModelParameters +from aidial_adapter_bedrock.llm.converse.adapter import ( + ConverseAdapter, + ConverseMessages, ) -from aidial_adapter_bedrock.llm.model.llama.conf import LlamaConf - - -def get_role(message: BaseMessage) -> Literal["system", "user", "assistant"]: - match message: - case SystemMessage(): - return "system" - case HumanRegularMessage(): - return "user" - case AIRegularMessage(): - return "assistant" - case _: - assert_never(message) - - -def encode_header(message: BaseMessage) -> str: - ret = "" - ret += "<|start_header_id|>" - ret += get_role(message) - ret += "<|end_header_id|>" - ret += "\n\n" - return ret - - -def encode_message(message: BaseMessage) -> str: - ret = encode_header(message) - ret += message.text_content.strip() - ret += "<|eot_id|>" - return ret - - -def encode_dialog_prompt(messages: List[BaseMessage]) -> str: - ret = "" - ret += "<|begin_of_text|>" - for message in messages: - ret += encode_message(message) - ret += encode_header(AIRegularMessage(content="")) - return ret +from aidial_adapter_bedrock.llm.converse.types import ( + ConverseDeployment, + ConverseRequestWrapper, +) +from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string -class LlamaChatEmulator(ChatEmulator): - def display(self, messages: List[BaseMessage]) -> Tuple[str, List[str]]: - return encode_dialog_prompt(messages), [] +class ConverseAdapterWithStreamingEmulation(ConverseAdapter): + """ + Llama 3 models support tools only in non-streaming mode. + So we need to run the request in non-streaming mode and then emulate streaming. 
+ """ - def get_ai_cue(self) -> Optional[str]: - return None + def is_stream(self, params: ModelParameters) -> bool: + if self.get_tool_config(params): + return False + return params.stream -def llama3_chat_partitioner(messages: List[BaseMessage]) -> List[int]: - return [1] * len(messages) +def input_tokenizer_factory( + deployment: ConverseDeployment, params: ConverseRequestWrapper +): + tool_tokens = default_tokenize_string(json.dumps(params.toolConfig)) + system_tokens = default_tokenize_string(json.dumps(params.system)) + async def tokenizer(msg_items: ConverseMessages) -> int: + tokens = sum( + default_tokenize_string(json.dumps(msg_item[0])) + for msg_item in msg_items + ) + return tokens + tool_tokens + system_tokens -llama3_config = LlamaConf( - chat_partitioner=llama3_chat_partitioner, - chat_emulator=LlamaChatEmulator(), -) + return tokenizer diff --git a/aidial_adapter_bedrock/llm/model/meta.py b/aidial_adapter_bedrock/llm/model/meta.py deleted file mode 100644 index a0b6e868..00000000 --- a/aidial_adapter_bedrock/llm/model/meta.py +++ /dev/null @@ -1,126 +0,0 @@ -from typing import Any, AsyncIterator, Dict, List, Optional - -from aidial_sdk.chat_completion import Message -from typing_extensions import override - -from aidial_adapter_bedrock.bedrock import ( - Bedrock, - ResponseWithInvocationMetricsMixin, -) -from aidial_adapter_bedrock.dial_api.request import ModelParameters -from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage -from aidial_adapter_bedrock.llm.chat_model import PseudoChatModel -from aidial_adapter_bedrock.llm.consumer import Consumer -from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_META -from aidial_adapter_bedrock.llm.model.llama.conf import LlamaConf -from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string -from aidial_adapter_bedrock.llm.tools.default_emulator import ( - default_tools_emulator, -) - - -class MetaResponse(ResponseWithInvocationMetricsMixin): - generation: str - prompt_token_count: Optional[int] - generation_token_count: Optional[int] - stop_reason: Optional[str] - - def content(self) -> str: - return self.generation - - def usage(self) -> TokenUsage: - return TokenUsage( - prompt_tokens=self.prompt_token_count or 0, - completion_tokens=self.generation_token_count or 0, - ) - - -def convert_params(params: ModelParameters) -> Dict[str, Any]: - ret = {} - - if params.temperature is not None: - ret["temperature"] = params.temperature - - if params.top_p is not None: - ret["top_p"] = params.top_p - - if params.max_tokens is not None: - ret["max_gen_len"] = params.max_tokens - else: - # Choosing reasonable default - ret["max_gen_len"] = DEFAULT_MAX_TOKENS_META - - return ret - - -def create_request(prompt: str, params: Dict[str, Any]) -> Dict[str, Any]: - return {"prompt": prompt, **params} - - -async def chunks_to_stream( - chunks: AsyncIterator[dict], usage: TokenUsage -) -> AsyncIterator[str]: - async for chunk in chunks: - resp = MetaResponse.parse_obj(chunk) - usage.accumulate(resp.usage_by_metrics()) - yield resp.content() - - -async def response_to_stream( - response: dict, usage: TokenUsage -) -> AsyncIterator[str]: - resp = MetaResponse.parse_obj(response) - usage.accumulate(resp.usage()) - yield resp.content() - - -class MetaAdapter(PseudoChatModel): - model: str - client: Bedrock - - @classmethod - def create(cls, client: Bedrock, model: str, conf: LlamaConf): - return cls( - client=client, - model=model, - tokenize_string=default_tokenize_string, - 
tools_emulator=default_tools_emulator, - chat_emulator=conf.chat_emulator, - partitioner=conf.chat_partitioner, - ) - - @override - def preprocess_messages(self, messages: List[Message]) -> List[Message]: - messages = super().preprocess_messages(messages) - - # Llama behaves strangely on empty prompt: - # it generate empty string, but claims to used up all available completion tokens. - # So replace it with a single space. - for msg in messages: - msg.content = msg.content or " " - - return messages - - async def predict( - self, consumer: Consumer, params: ModelParameters, prompt: str - ): - args = create_request(prompt, convert_params(params)) - - usage = TokenUsage() - - if params.stream: - chunks = self.client.ainvoke_streaming(self.model, args) - stream = chunks_to_stream(chunks, usage) - else: - response, _headers = await self.client.ainvoke_non_streaming( - self.model, args - ) - stream = response_to_stream(response, usage) - - stream = self.post_process_stream(stream, params, self.chat_emulator) - - async for content in stream: - consumer.append_content(content) - consumer.close_content() - - consumer.add_usage(usage) diff --git a/poetry.lock b/poetry.lock index dd3bdb8e..38ad6c76 100644 --- a/poetry.lock +++ b/poetry.lock @@ -337,41 +337,41 @@ uvloop = ["uvloop (>=0.15.2)"] [[package]] name = "boto3" -version = "1.28.57" +version = "1.35.41" description = "The AWS SDK for Python" optional = false -python-versions = ">= 3.7" +python-versions = ">=3.8" files = [ - {file = "boto3-1.28.57-py3-none-any.whl", hash = "sha256:5ddf24cf52c7fb6aaa332eaa08ae8c2afc8f2d1e8860680728533dd573904e32"}, - {file = "boto3-1.28.57.tar.gz", hash = "sha256:e2d2824ba6459b330d097e94039a9c4f96ae3f4bcdc731d620589ad79dcd16d3"}, + {file = "boto3-1.35.41-py3-none-any.whl", hash = "sha256:2bf7e7f376aee52155fc4ae4487f29333a6bcdf3a05c3bc4fede10b972d951a6"}, + {file = "boto3-1.35.41.tar.gz", hash = "sha256:e74bc6d69c04ca611b7f58afe08e2ded6cb6504a4a80557b656abeefee395f88"}, ] [package.dependencies] -botocore = ">=1.31.57,<1.32.0" +botocore = ">=1.35.41,<1.36.0" jmespath = ">=0.7.1,<2.0.0" -s3transfer = ">=0.7.0,<0.8.0" +s3transfer = ">=0.10.0,<0.11.0" [package.extras] crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.31.57" +version = "1.35.41" description = "Low-level, data-driven core of boto 3." 
optional = false -python-versions = ">= 3.7" +python-versions = ">=3.8" files = [ - {file = "botocore-1.31.57-py3-none-any.whl", hash = "sha256:af006248276ff8e19e3ec7214478f6257035eb40aed865e405486500471ae71b"}, - {file = "botocore-1.31.57.tar.gz", hash = "sha256:301436174635bec739b225b840fc365ca00e5c1a63e5b2a19ee679d204e01b78"}, + {file = "botocore-1.35.41-py3-none-any.whl", hash = "sha256:915c4d81e3a0be3b793c1e2efdf19af1d0a9cd4a2d8de08ee18216c14d67764b"}, + {file = "botocore-1.35.41.tar.gz", hash = "sha256:8a09a32136df8768190a6c92f0240cd59c30deb99c89026563efadbbed41fa00"}, ] [package.dependencies] jmespath = ">=0.7.1,<2.0.0" python-dateutil = ">=2.1,<3.0.0" -urllib3 = ">=1.25.4,<1.27" +urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""} [package.extras] -crt = ["awscrt (==0.16.26)"] +crt = ["awscrt (==0.22.0)"] [[package]] name = "certifi" @@ -2190,20 +2190,20 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "s3transfer" -version = "0.7.0" +version = "0.10.3" description = "An Amazon S3 Transfer Manager" optional = false -python-versions = ">= 3.7" +python-versions = ">=3.8" files = [ - {file = "s3transfer-0.7.0-py3-none-any.whl", hash = "sha256:10d6923c6359175f264811ef4bf6161a3156ce8e350e705396a7557d6293c33a"}, - {file = "s3transfer-0.7.0.tar.gz", hash = "sha256:fd3889a66f5fe17299fe75b82eae6cf722554edca744ca5d5fe308b104883d2e"}, + {file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"}, + {file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"}, ] [package.dependencies] -botocore = ">=1.12.36,<2.0a.0" +botocore = ">=1.33.2,<2.0a.0" [package.extras] -crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] [[package]] name = "setuptools" @@ -2652,4 +2652,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it [metadata] lock-version = "2.0" python-versions = "^3.11,<4.0" -content-hash = "831f19b9b372a53821d8076e6900345d47804fab89d5606f8f919b03d7ed7647" +content-hash = "a99ddf2ca2b4ce8adf619fb1b52c1ec3d62d476cad7723c362830d8fe6a22741" diff --git a/pyproject.toml b/pyproject.toml index 91051e27..e4fdf61e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,8 +15,8 @@ repository = "https://github.com/epam/ai-dial-adapter-bedrock/" [tool.poetry.dependencies] python = "^3.11,<4.0" -boto3 = "1.28.57" -botocore = "1.31.57" +boto3 = "1.35.41" +botocore = "1.35.41" aidial-sdk = {version = "0.14.0", extras = ["telemetry"]} anthropic = {version = "0.28.1", extras = ["bedrock"]} fastapi = "0.115.2" diff --git a/tests/integration_tests/constants.py b/tests/integration_tests/constants.py index 8504d015..6c15a9bc 100644 --- a/tests/integration_tests/constants.py +++ b/tests/integration_tests/constants.py @@ -1,6 +1,14 @@ +from pathlib import Path + from aidial_adapter_bedrock.utils.resource import Resource BLUE_PNG_PICTURE = Resource.from_base64( type="image/png", data_base64="iVBORw0KGgoAAAANSUhEUgAAAAMAAAADCAIAAADZSiLoAAAAF0lEQVR4nGNkYPjPwMDAwMDAxAADCBYAG10BBdmz9y8AAAAASUVORK5CYII=", ) +CURRENT_DIR = Path(__file__).parent +SAMPLE_DOG_IMAGE_PATH = CURRENT_DIR / "images" / "dog-sample-image.png" +SAMPLE_DOG_RESOURCE = Resource( + type="image/png", + data=SAMPLE_DOG_IMAGE_PATH.read_bytes(), +) diff --git a/tests/integration_tests/test_chat_completion.py b/tests/integration_tests/test_chat_completion.py index adf8cb1d..a1030a60 100644 --- 
a/tests/integration_tests/test_chat_completion.py +++ b/tests/integration_tests/test_chat_completion.py @@ -16,7 +16,7 @@ UpstreamConfig, ) from aidial_adapter_bedrock.deployments import ChatCompletionDeployment -from tests.integration_tests.constants import BLUE_PNG_PICTURE +from tests.integration_tests.constants import SAMPLE_DOG_RESOURCE from tests.utils.openai import ( GET_WEATHER_FUNCTION, ChatCompletionResult, @@ -70,14 +70,16 @@ class TestCase: functions: List[Function] | None tools: List[ChatCompletionToolParam] | None + temperature: float = 0.0 def get_id(self): max_tokens_str = f"maxt={self.max_tokens}" if self.max_tokens else "" stop_sequence_str = f"stop={self.stop}" if self.stop else "" n_str = f"n={self.n}" if self.n else "" + temperature_str = f"temp={self.temperature}" if self.temperature else "" return sanitize_test_name( f"{self.deployment.value} {self.streaming} {max_tokens_str} " - f"{stop_sequence_str} {n_str} {self.name}" + f"{stop_sequence_str} {n_str} {temperature_str} {self.name}" ) @@ -99,13 +101,17 @@ def get_id(self): ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_US: _WEST, ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_V2: _WEST, ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_V2_US: _WEST, - ChatCompletionDeployment.META_LLAMA2_13B_CHAT_V1: _WEST, - ChatCompletionDeployment.META_LLAMA2_70B_CHAT_V1: _WEST, ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1: _WEST, ChatCompletionDeployment.META_LLAMA3_70B_INSTRUCT_V1: _WEST, - ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1: _WEST, - ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1: _WEST, ChatCompletionDeployment.META_LLAMA3_1_8B_INSTRUCT_V1: _WEST, + ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1: _WEST, + ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1: _WEST, + # Llama 3.2 1B is too unstable in responses for integration tests + # Sometimes it cannot calculate 2+2 + # ChatCompletionDeployment.META_LLAMA3_2_1B_INSTRUCT_V1: _WEST, + ChatCompletionDeployment.META_LLAMA3_2_3B_INSTRUCT_V1: _WEST, + ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1: _WEST, + ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1: _WEST, ChatCompletionDeployment.COHERE_COMMAND_TEXT_V14: _WEST, ChatCompletionDeployment.COHERE_COMMAND_LIGHT_TEXT_V14: _WEST, } @@ -127,6 +133,9 @@ def supports_tools(deployment: ChatCompletionDeployment) -> bool: ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_HAIKU_EU, ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_OPUS, ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_OPUS_US, + ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1, + ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1, + ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1, ] @@ -134,6 +143,8 @@ def supports_parallel_tool_calls(deployment: ChatCompletionDeployment) -> bool: return deployment not in [ ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_V2, ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_V2_US, + ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1, + ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1, ] and supports_tools(deployment) @@ -141,6 +152,13 @@ def is_llama3(deployment: ChatCompletionDeployment) -> bool: return deployment in [ ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1, ChatCompletionDeployment.META_LLAMA3_70B_INSTRUCT_V1, + ChatCompletionDeployment.META_LLAMA3_1_8B_INSTRUCT_V1, + ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1, + ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1, + 
ChatCompletionDeployment.META_LLAMA3_2_1B_INSTRUCT_V1, + ChatCompletionDeployment.META_LLAMA3_2_3B_INSTRUCT_V1, + ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1, + ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1, ] @@ -184,7 +202,10 @@ def is_ai21(deployment: ChatCompletionDeployment) -> bool: def is_vision_model(deployment: ChatCompletionDeployment) -> bool: - return is_claude3(deployment) + return is_claude3(deployment) or deployment in [ + ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1, + ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1, + ] def are_tools_emulated(deployment: ChatCompletionDeployment) -> bool: @@ -218,6 +239,7 @@ def test_case( stop: List[str] | None = None, functions: List[Function] | None = None, tools: List[ChatCompletionToolParam] | None = None, + temperature: float = 0.0, ) -> None: test_cases.append( TestCase( @@ -232,6 +254,7 @@ def test_case( n, functions, tools, + temperature, ) ) @@ -304,6 +327,14 @@ def dial_recall_expected(r: ChatCompletionResult): expected_empty_message_error = streaming_error( cohere_invalid_request_error ) + elif is_llama3(deployment): + expected_empty_message_error = streaming_error( + ExpectedException( + type=BadRequestError, + message="Add text to the text field, and try again.", + status_code=400, + ) + ) test_case( name="empty user message", @@ -325,6 +356,14 @@ def dial_recall_expected(r: ChatCompletionResult): expected_whitespace_message = streaming_error( cohere_invalid_request_error ) + elif is_llama3(deployment): + expected_whitespace_message = streaming_error( + ExpectedException( + type=BadRequestError, + message="Add text to the text field, and try again.", + status_code=400, + ) + ) test_case( name="single space user message", @@ -337,16 +376,16 @@ def dial_recall_expected(r: ChatCompletionResult): content = "describe the image" for idx, user_message in enumerate( [ - user_with_attachment_data(content, BLUE_PNG_PICTURE), - user_with_attachment_url(content, BLUE_PNG_PICTURE), - user_with_image_url(content, BLUE_PNG_PICTURE), + user_with_attachment_data(content, SAMPLE_DOG_RESOURCE), + user_with_attachment_url(content, SAMPLE_DOG_RESOURCE), + user_with_image_url(content, SAMPLE_DOG_RESOURCE), ] ): test_case( name=f"describe image {idx}", max_tokens=100, messages=[sys("be a helpful assistant"), user_message], # type: ignore - expected=lambda s: "blue" in s.content.lower(), + expected=lambda s: "dog" in s.content.lower(), ) test_case( @@ -370,10 +409,17 @@ def dial_recall_expected(r: ChatCompletionResult): ) if is_llama3(deployment): + test_case( - name="out of turn", + name="out_of_turn", messages=[ai("hello"), user("what's 7+5?")], - expected=lambda s: "12" in s.content.lower(), + expected=streaming_error( + ExpectedException( + type=BadRequestError, + message="A conversation must start with a user message", + status_code=400, + ) + ), ) test_case( @@ -405,11 +451,13 @@ def dial_recall_expected(r: ChatCompletionResult): query = f"What's the temperature in {' and in '.join(city_names)} in celsius?" 
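+    # The query names several cities so that the function and tool scenarios
+    # below can exercise both single and multiple (parallel) tool calls.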
init_messages = [ - sys("act as a helpful assistant"), user("2+3=?"), ai("5"), user(query), ] + # Llama 3 works badly with system messages along tools + if not is_llama3(deployment): + init_messages.insert(0, sys("act as a helpful assistant")) def create_fun_args(city: str): return { @@ -433,6 +481,7 @@ def check_fun_args(city: str): expected=lambda s, n=city_names[0]: is_valid_function_call( s.function_call, fun_name, check_fun_args(n) ), + temperature=1 if is_llama3(deployment) else 0.0, ) function_req = ai_function( @@ -454,6 +503,7 @@ def check_fun_args(city: str): expected=lambda s, t=city_temps[0]: s.content_contains_all( [t] ), + temperature=1 if is_llama3(deployment) else 0.0, ) else: test_case( @@ -467,6 +517,7 @@ def check_fun_args(city: str): expected=lambda s, n=city_names[1]: is_valid_function_call( s.function_call, fun_name, check_fun_args(n) ), + temperature=1 if is_llama3(deployment) else 0.0, ) # Tools @@ -501,6 +552,7 @@ def _check(id: str) -> bool: ) for idx in range(len(n)) ), + temperature=1 if is_llama3(deployment) else 0.0, ) tool_reqs = ai_tools( @@ -523,6 +575,7 @@ def _check(id: str) -> bool: messages=[*init_messages, tool_reqs, *tool_resps], tools=[tool], expected=lambda s, t=city_temps: s.content_contains_all(t), + temperature=1 if is_llama3(deployment) else 0.0, ) return test_cases @@ -562,6 +615,7 @@ async def run_chat_completion() -> ChatCompletionResult: test.n, test.functions, test.tools, + test.temperature, ) if isinstance(test.expected, ExpectedException): diff --git a/tests/integration_tests/test_stable_diffusion.py b/tests/integration_tests/test_stable_diffusion.py index dd412d3b..05e5b4ac 100644 --- a/tests/integration_tests/test_stable_diffusion.py +++ b/tests/integration_tests/test_stable_diffusion.py @@ -1,5 +1,4 @@ import base64 -from pathlib import Path from typing import Dict from unittest.mock import patch @@ -12,7 +11,10 @@ ) from aidial_adapter_bedrock.deployments import ChatCompletionDeployment from aidial_adapter_bedrock.utils.resource import Resource -from tests.integration_tests.constants import BLUE_PNG_PICTURE +from tests.integration_tests.constants import ( + BLUE_PNG_PICTURE, + SAMPLE_DOG_RESOURCE, +) from tests.utils.mock_storage import MockFileStorage from tests.utils.openai import ( user, @@ -34,13 +36,6 @@ ] VISION_MODEL = ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_US -CURRENT_DIR = Path(__file__).parent -SAMPLE_DOG_IMAGE_PATH = CURRENT_DIR / "images" / "dog-sample-image.png" -SAMPLE_DOG_RESOURCE = Resource( - type="image/png", - data=SAMPLE_DOG_IMAGE_PATH.read_bytes(), -) - def get_upstream_headers(region: str) -> Dict[str, str]: return { diff --git a/tests/unit_tests/chat_emulation/test_llama2_chat.py b/tests/unit_tests/chat_emulation/test_llama2_chat.py deleted file mode 100644 index a0010a87..00000000 --- a/tests/unit_tests/chat_emulation/test_llama2_chat.py +++ /dev/null @@ -1,168 +0,0 @@ -from typing import List, Optional - -import pytest - -from aidial_adapter_bedrock.llm.chat_model import keep_last_and_system_messages -from aidial_adapter_bedrock.llm.errors import ValidationError -from aidial_adapter_bedrock.llm.message import BaseMessage -from aidial_adapter_bedrock.llm.model.llama.v2 import llama2_config -from aidial_adapter_bedrock.llm.truncate_prompt import ( - DiscardedMessages, - TruncatePromptError, - compute_discarded_messages, -) -from tests.utils.messages import ai, sys, user - -llama2_chat_emulator = llama2_config.chat_emulator -llama2_chat_partitioner = llama2_config.chat_partitioner - - -async def 
truncate_prompt_by_words( - messages: List[BaseMessage], - user_limit: int, - model_limit: Optional[int] = None, -) -> DiscardedMessages | TruncatePromptError: - async def _tokenize_by_words(messages: List[BaseMessage]) -> int: - return sum(len(msg.text_content.split()) for msg in messages) - - return await compute_discarded_messages( - messages=messages, - tokenizer=_tokenize_by_words, - keep_message=keep_last_and_system_messages, - partitioner=llama2_chat_partitioner, - model_limit=model_limit, - user_limit=user_limit, - ) - - -def test_construction_single_message(): - messages: List[BaseMessage] = [ - user(" human message1 "), - ] - - text, stop_sequences = llama2_chat_emulator.display(messages) - - assert stop_sequences == [] - assert text == "[INST] human message1 [/INST]" - - -def test_construction_many_without_system(): - messages = [ - user(" human message1 "), - ai(" ai message1 "), - user(" human message2 "), - ] - - text, stop_sequences = llama2_chat_emulator.display(messages) - - assert stop_sequences == [] - assert text == "".join( - [ - "[INST] human message1 [/INST]", - " ai message1 ", - "[INST] human message2 [/INST]", - ] - ) - - -def test_construction_many_with_system(): - messages = [ - sys(" system message1 "), - user(" human message1 "), - ai(" ai message1 "), - user(" human message2 "), - ] - - text, stop_sequences = llama2_chat_emulator.display(messages) - - assert stop_sequences == [] - assert text == "".join( - [ - "[INST] <>\n system message1 \n<>\n\n human message1 [/INST]", - " ai message1 ", - "[INST] human message2 [/INST]", - ] - ) - - -def test_invalid_alternation(): - messages = [ - ai(" ai message1 "), - user(" human message1 "), - user(" human message2 "), - ] - - with pytest.raises(ValidationError) as exc_info: - llama2_chat_emulator.display(messages) - - assert exc_info.value.message == ( - "The model only supports initial optional system message and" - " follow-up alternating human/assistant messages" - ) - - -def test_invalid_last_message(): - messages = [ - user(" human message1 "), - ai(" ai message1 "), - user(" human message2 "), - ai(" ai message2 "), - ] - - with pytest.raises(ValidationError) as exc_info: - llama2_chat_emulator.display(messages) - - assert exc_info.value.message == "The last message must be from user" - - -turns_sys = [ - sys("system"), - user("hello"), - ai("hi"), - user("ping"), - ai("pong"), - user("improvise"), -] - -turns_no_sys = turns_sys[1:] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "messages, user_limit, expected", - [ - ( - turns_sys, - 1, - "The requested maximum prompt tokens is 1. " - "However, the system messages and the last user message resulted in 2 tokens. 
" - "Please reduce the length of the messages or increase the maximum prompt tokens.", - ), - (turns_sys, 2, [1, 2, 3, 4]), - (turns_sys, 3, [1, 2, 3, 4]), - (turns_sys, 4, [1, 2]), - (turns_sys, 5, [1, 2]), - (turns_sys, 6, []), - (turns_no_sys, 1, [0, 1, 2, 3]), - (turns_no_sys, 2, [0, 1, 2, 3]), - (turns_no_sys, 3, [0, 1]), - (turns_no_sys, 4, [0, 1]), - (turns_no_sys, 5, []), - ], -) -async def test_multi_turn_dialogue( - messages: List[BaseMessage], - user_limit: int, - expected: DiscardedMessages | str, -): - discarded_messages = await truncate_prompt_by_words( - messages=messages, user_limit=user_limit - ) - - if isinstance(expected, str): - assert ( - isinstance(discarded_messages, TruncatePromptError) - and discarded_messages.print() == expected - ) - else: - assert discarded_messages == expected diff --git a/tests/unit_tests/converse/test_converse_adapter.py b/tests/unit_tests/converse/test_converse_adapter.py new file mode 100644 index 00000000..15ce6084 --- /dev/null +++ b/tests/unit_tests/converse/test_converse_adapter.py @@ -0,0 +1,427 @@ +from dataclasses import dataclass +from typing import List + +import pytest +from aidial_sdk.chat_completion.request import ( + Function, + FunctionCall, + ImageURL, + Message, + MessageContentImagePart, + MessageContentTextPart, + Role, + ToolCall, +) + +from aidial_adapter_bedrock.aws_client_config import AWSClientConfig +from aidial_adapter_bedrock.bedrock import Bedrock +from aidial_adapter_bedrock.dial_api.request import ModelParameters +from aidial_adapter_bedrock.llm.converse.adapter import ConverseAdapter +from aidial_adapter_bedrock.llm.converse.types import ( + ConverseImagePart, + ConverseImagePartConfig, + ConverseImageSource, + ConverseMessage, + ConverseRequestWrapper, + ConverseRole, + ConverseTextPart, + ConverseToolResultPart, + ConverseToolUseConfig, + ConverseToolUsePart, + InferenceConfig, +) +from aidial_adapter_bedrock.llm.errors import ValidationError +from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig +from aidial_adapter_bedrock.utils.list_projection import ListProjection +from tests.integration_tests.constants import BLUE_PNG_PICTURE + + +async def _input_tokenizer_factory(_deployment, _params): + async def _test_tokenizer(_messages) -> int: + return 100 + + return _test_tokenizer + + +@dataclass +class ExpectedException: + type: type[Exception] + message: str + + +@dataclass +class TestCase: + __test__ = False + name: str + messages: List[Message] + params: ModelParameters + expected_output: ConverseRequestWrapper | None = None + expected_error: ExpectedException | None = None + + +default_inference_config = InferenceConfig(stopSequences=[]) +TEST_CASES = [ + TestCase( + name="plain_message", + messages=[Message(role=Role.USER, content="Hello, world!")], + params=ModelParameters(tool_config=None), + expected_output=ConverseRequestWrapper( + inferenceConfig=default_inference_config, + messages=ListProjection( + list=[ + ( + ConverseMessage( + role=ConverseRole.USER, + content=[ConverseTextPart(text="Hello, world!")], + ), + {0}, + ) + ] + ), + ), + ), + TestCase( + name="system_message", + messages=[ + Message(role=Role.SYSTEM, content="You are a helpful assistant."), + Message(role=Role.USER, content="Hello!"), + ], + params=ModelParameters(tool_config=None), + expected_output=ConverseRequestWrapper( + inferenceConfig=default_inference_config, + system=[ConverseTextPart(text="You are a helpful assistant.")], + messages=ListProjection( + list=[ + ( + ConverseMessage( + role=ConverseRole.USER, + 
content=[ConverseTextPart(text="Hello!")], + ), + {1}, + ) + ] + ), + ), + ), + TestCase( + name="system_message_after_user", + messages=[ + Message(role=Role.SYSTEM, content="You are a helpful assistant."), + Message(role=Role.USER, content="Hello!"), + Message(role=Role.SYSTEM, content="You are a helpful assistant."), + ], + params=ModelParameters(tool_config=None), + expected_error=ExpectedException( + type=ValidationError, + message="A system message can only follow another system message", + ), + ), + TestCase( + name="multiple_system_messages", + messages=[ + Message(role=Role.SYSTEM, content="You are a helpful assistant."), + Message(role=Role.SYSTEM, content="You are also very friendly."), + Message(role=Role.USER, content="Hello!"), + ], + params=ModelParameters(tool_config=None), + expected_output=ConverseRequestWrapper( + inferenceConfig=default_inference_config, + system=[ + ConverseTextPart( + text="You are a helpful assistant.\n\nYou are also very friendly." + ), + ], + messages=ListProjection( + list=[ + ( + ConverseMessage( + role=ConverseRole.USER, + content=[ConverseTextPart(text="Hello!")], + ), + {2}, + ) + ] + ), + ), + ), + TestCase( + name="system_message_multiple_parts", + messages=[ + Message( + role=Role.SYSTEM, + content=[ + MessageContentTextPart( + type="text", text="You are a helpful assistant." + ), + MessageContentTextPart( + type="text", text="You are also very friendly." + ), + ], + ), + Message(role=Role.USER, content="Hello!"), + ], + params=ModelParameters(tool_config=None), + expected_output=ConverseRequestWrapper( + inferenceConfig=default_inference_config, + system=[ + ConverseTextPart( + text="You are a helpful assistant.\n\nYou are also very friendly." + ), + ], + messages=ListProjection( + list=[ + ( + ConverseMessage( + role=ConverseRole.USER, + content=[ConverseTextPart(text="Hello!")], + ), + {1}, + ) + ] + ), + ), + ), + TestCase( + name="system_message_with_forbidden_image", + messages=[ + Message( + role=Role.SYSTEM, + content=[ + MessageContentTextPart( + type="text", text="You are a helpful assistant." 
+ ), + MessageContentImagePart( + type="image_url", + image_url=ImageURL(url=BLUE_PNG_PICTURE.to_data_url()), + ), + ], + ), + Message(role=Role.USER, content="Hello!"), + ], + params=ModelParameters(tool_config=None), + expected_error=ExpectedException( + type=ValidationError, + message="System messages cannot contain images", + ), + ), + TestCase( + name="tools_convert", + messages=[ + Message(role=Role.USER, content="What's the weather?"), + Message( + role=Role.ASSISTANT, + content=None, + tool_calls=[ + ToolCall( + index=0, + id="call_123", + type="function", + function=FunctionCall( + name="get_weather", + arguments='{"location": "London"}', + ), + ) + ], + ), + Message( + role=Role.TOOL, + content='{"temperature": "20C"}', + tool_call_id="call_123", + ), + ], + params=ModelParameters( + tool_config=ToolsConfig( + functions=[ + Function( + name="get_weather", + description="Get the weather", + parameters={"type": "object", "properties": {}}, + ) + ], + required=True, + tool_ids=None, + ) + ), + expected_output=ConverseRequestWrapper( + inferenceConfig=default_inference_config, + toolConfig={ + "tools": [ + { + "toolSpec": { + "name": "get_weather", + "description": "Get the weather", + "inputSchema": { + "json": {"properties": {}, "type": "object"} + }, + } + } + ], + "toolChoice": {"any": {}}, + }, + messages=ListProjection( + list=[ + ( + ConverseMessage( + role=ConverseRole.USER, + content=[ + ConverseTextPart(text="What's the weather?") + ], + ), + {0}, + ), + ( + ConverseMessage( + role=ConverseRole.ASSISTANT, + content=[ + ConverseToolUsePart( + toolUse=ConverseToolUseConfig( + toolUseId="call_123", + name="get_weather", + input={"location": "London"}, + ) + ) + ], + ), + {1}, + ), + ( + ConverseMessage( + role=ConverseRole.USER, + content=[ + ConverseToolResultPart( + toolResult={ + "toolUseId": "call_123", + "content": [ + {"json": {"temperature": "20C"}} + ], + "status": "success", + } + ) + ], + ), + {2}, + ), + ] + ), + ), + ), + TestCase( + name="content_parts", + messages=[ + Message( + role=Role.USER, + content=[ + MessageContentTextPart(type="text", text="Hello!"), + MessageContentImagePart( + type="image_url", + image_url=ImageURL(url=BLUE_PNG_PICTURE.to_data_url()), + ), + ], + ) + ], + params=ModelParameters(tool_config=None), + expected_output=ConverseRequestWrapper( + inferenceConfig=default_inference_config, + messages=ListProjection( + list=[ + ( + ConverseMessage( + role=ConverseRole.USER, + content=[ + ConverseTextPart(text="Hello!"), + ConverseImagePart( + image=ConverseImagePartConfig( + format="png", + source=ConverseImageSource( + bytes=BLUE_PNG_PICTURE.data + ), + ) + ), + ], + ), + {0}, + ) + ] + ), + ), + ), + TestCase( + name="shrink_messages", + messages=[ + Message(role=Role.USER, content="Say hello."), + Message(role=Role.USER, content="And have a good day."), + Message( + role=Role.ASSISTANT, + content="Hello", + ), + Message( + role=Role.ASSISTANT, + content=[ + MessageContentTextPart(type="text", text="Have a nice"), + MessageContentTextPart(type="text", text="day!"), + ], + ), + ], + params=ModelParameters(temperature=10), + expected_output=ConverseRequestWrapper( + inferenceConfig=InferenceConfig(temperature=10, stopSequences=[]), + messages=ListProjection( + list=[ + ( + ConverseMessage( + role=ConverseRole.USER, + content=[ + ConverseTextPart(text="Say hello."), + ConverseTextPart(text="And have a good day."), + ], + ), + {0, 1}, + ), + ( + ConverseMessage( + role=ConverseRole.ASSISTANT, + content=[ + ConverseTextPart(text="Hello"), + 
ConverseTextPart(text="Have a nice"), + ConverseTextPart(text="day!"), + ], + ), + {2, 3}, + ), + ] + ), + ), + ), +] + + +@pytest.mark.parametrize( + "test_case", TEST_CASES, ids=lambda test_case: test_case.name +) +@pytest.mark.asyncio +async def test_converse_adapter( + test_case: TestCase, +): + adapter = ConverseAdapter( + deployment="test", + bedrock=await Bedrock.acreate(AWSClientConfig(region="us-east-1")), + tokenize_text=lambda x: len(x), + input_tokenizer_factory=_input_tokenizer_factory, # type: ignore + support_tools=True, + storage=None, + ) + construct_coro = adapter.construct_converse_params( + messages=test_case.messages, + params=test_case.params, + ) + + if test_case.expected_error is not None: + with pytest.raises(test_case.expected_error.type) as exc_info: + converse_request = await construct_coro + assert hasattr(exc_info.value, "message") + error_message = getattr(exc_info.value, "message") + assert isinstance(error_message, str) + assert error_message == test_case.expected_error.message + else: + converse_request = await construct_coro + assert converse_request == test_case.expected_output diff --git a/tests/unit_tests/converse/test_to_converse_message.py b/tests/unit_tests/converse/test_to_converse_message.py new file mode 100644 index 00000000..e86052a4 --- /dev/null +++ b/tests/unit_tests/converse/test_to_converse_message.py @@ -0,0 +1,184 @@ +import pytest +from aidial_sdk.chat_completion import FunctionCall +from aidial_sdk.chat_completion import Message as DialMessage +from aidial_sdk.chat_completion import Role as DialRole +from aidial_sdk.chat_completion import ToolCall + +from aidial_adapter_bedrock.llm.converse.input import to_converse_message + + +@pytest.mark.asyncio +async def test_to_converse_message_text(): + dial_message = DialMessage(role=DialRole.USER, content="Hello, world!") + converse_message = await to_converse_message(dial_message, storage=None) + + assert converse_message == { + "role": "user", + "content": [{"text": "Hello, world!"}], + } + + +@pytest.mark.asyncio +async def test_to_converse_message_assistant(): + dial_message = DialMessage(role=DialRole.ASSISTANT, content="Hello") + converse_message = await to_converse_message(dial_message, storage=None) + + assert converse_message == { + "role": "assistant", + "content": [{"text": "Hello"}], + } + + +@pytest.mark.asyncio +async def test_to_converse_message_function_call_no_content(): + dial_message = DialMessage( + role=DialRole.ASSISTANT, + function_call=FunctionCall( + name="get_weather", arguments='{"city": "Paris"}' + ), + ) + converse_message = await to_converse_message(dial_message, storage=None) + assert converse_message == { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "get_weather", + "name": "get_weather", + "input": {"city": "Paris"}, + } + }, + ], + } + + +@pytest.mark.asyncio +async def test_to_converse_message_function_call_with_content(): + dial_message = DialMessage( + role=DialRole.ASSISTANT, + content="Calling a function", + function_call=FunctionCall( + name="get_weather", arguments='{"city": "Paris"}' + ), + ) + converse_message = await to_converse_message(dial_message, storage=None) + assert converse_message == { + "role": "assistant", + "content": [ + {"text": "Calling a function"}, + { + "toolUse": { + "toolUseId": "get_weather", + "name": "get_weather", + "input": {"city": "Paris"}, + } + }, + ], + } + + +@pytest.mark.asyncio +async def test_to_converse_message_tool_call_no_content(): + dial_message = DialMessage( + 
role=DialRole.ASSISTANT, + tool_calls=[ + ToolCall( + index=None, + id="123", + type="function", + function=FunctionCall( + name="get_weather", arguments='{"city": "Paris"}' + ), + ) + ], + ) + converse_message = await to_converse_message(dial_message, storage=None) + assert converse_message == { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "get_weather", + "input": {"city": "Paris"}, + } + }, + ], + } + + +@pytest.mark.asyncio +async def test_to_converse_message_tool_call_with_content(): + dial_message = DialMessage( + role=DialRole.ASSISTANT, + content="Calling a function", + tool_calls=[ + ToolCall( + index=None, + id="123", + type="function", + function=FunctionCall( + name="get_weather", arguments='{"city": "Paris"}' + ), + ) + ], + ) + converse_message = await to_converse_message(dial_message, storage=None) + assert converse_message == { + "role": "assistant", + "content": [ + {"text": "Calling a function"}, + { + "toolUse": { + "toolUseId": "123", + "name": "get_weather", + "input": {"city": "Paris"}, + } + }, + ], + } + + +@pytest.mark.asyncio +async def test_to_converse_message_multiple_tool_calls(): + dial_message = DialMessage( + role=DialRole.ASSISTANT, + tool_calls=[ + ToolCall( + index=None, + id="123", + type="function", + function=FunctionCall( + name="get_weather", arguments='{"city": "Paris"}' + ), + ), + ToolCall( + index=None, + id="456", + type="function", + function=FunctionCall( + name="get_weather", arguments='{"city": "London"}' + ), + ), + ], + ) + converse_message = await to_converse_message(dial_message, storage=None) + assert converse_message == { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "get_weather", + "input": {"city": "Paris"}, + } + }, + { + "toolUse": { + "toolUseId": "456", + "name": "get_weather", + "input": {"city": "London"}, + } + }, + ], + } diff --git a/tests/unit_tests/test_endpoints.py b/tests/unit_tests/test_endpoints.py index 79afa80b..3748ae25 100644 --- a/tests/unit_tests/test_endpoints.py +++ b/tests/unit_tests/test_endpoints.py @@ -36,13 +36,15 @@ ), (ChatCompletionDeployment.STABILITY_STABLE_IMAGE_ULTRA_V1, False, True), (ChatCompletionDeployment.STABILITY_STABLE_IMAGE_CORE_V1, False, True), - (ChatCompletionDeployment.META_LLAMA2_13B_CHAT_V1, True, True), - (ChatCompletionDeployment.META_LLAMA2_70B_CHAT_V1, True, True), (ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1, True, True), (ChatCompletionDeployment.META_LLAMA3_70B_INSTRUCT_V1, True, True), - (ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1, True, True), - (ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1, True, True), (ChatCompletionDeployment.META_LLAMA3_1_8B_INSTRUCT_V1, True, True), + (ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1, True, True), + (ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1, True, True), + (ChatCompletionDeployment.META_LLAMA3_2_1B_INSTRUCT_V1, True, True), + (ChatCompletionDeployment.META_LLAMA3_2_3B_INSTRUCT_V1, True, True), + (ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1, True, True), + (ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1, True, True), (ChatCompletionDeployment.COHERE_COMMAND_TEXT_V14, True, True), (ChatCompletionDeployment.COHERE_COMMAND_LIGHT_TEXT_V14, True, True), ] diff --git a/tests/utils/openai.py b/tests/utils/openai.py index 04f8475c..81033d00 100644 --- a/tests/utils/openai.py +++ b/tests/utils/openai.py @@ -197,6 +197,7 @@ async def chat_completion( n: Optional[int], functions: 
List[Function] | None, tools: List[ChatCompletionToolParam] | None, + temperature: float = 0.0, ) -> ChatCompletionResult: async def get_response() -> ChatCompletion: response = await client.chat.completions.create( @@ -205,7 +206,7 @@ async def get_response() -> ChatCompletion: stream=stream, stop=stop, max_tokens=max_tokens, - temperature=0.0, + temperature=temperature, n=n, function_call="auto" if functions is not None else NOT_GIVEN, functions=functions or NOT_GIVEN,