diff --git a/README.md b/README.md
index 46653802..8a1de783 100644
--- a/README.md
+++ b/README.md
@@ -20,13 +20,15 @@ Note that a model supports `/truncate_prompt` endpoint if and only if it support
|Anthropic|Claude 2.1|anthropic.claude-v2:1|text-to-text|✅|✅|✅|
|Anthropic|Claude 2|anthropic.claude-v2|text-to-text|✅|✅|❌|
|Anthropic|Claude Instant 1.2|anthropic.claude-instant-v1|text-to-text|🟡|🟡|❌|
-|Meta|Llama 3.1 405B Instruct|meta.llama3-1-405b-instruct-v1:0|text-to-text|🟡|🟡|❌|
-|Meta|Llama 3.1 70B Instruct|meta.llama3-1-70b-instruct-v1:0|text-to-text|🟡|🟡|❌|
+|Meta|Llama 3.2 90B Instruct|us.meta.llama3-2-90b-instruct-v1:0|text-to-text, image-to-text|🟡|🟡|✅|
+|Meta|Llama 3.2 11B Instruct|us.meta.llama3-2-11b-instruct-v1:0|text-to-text, image-to-text|🟡|🟡|❌|
+|Meta|Llama 3.2 3B Instruct|us.meta.llama3-2-3b-instruct-v1:0|text-to-text|🟡|🟡|❌|
+|Meta|Llama 3.2 1B Instruct|us.meta.llama3-2-1b-instruct-v1:0|text-to-text|🟡|🟡|❌|
+|Meta|Llama 3.1 405B Instruct|meta.llama3-1-405b-instruct-v1:0|text-to-text|🟡|🟡|✅|
+|Meta|Llama 3.1 70B Instruct|meta.llama3-1-70b-instruct-v1:0|text-to-text|🟡|🟡|✅|
|Meta|Llama 3.1 8B Instruct|meta.llama3-1-8b-instruct-v1:0|text-to-text|🟡|🟡|❌|
|Meta|Llama 3 Chat 70B Instruct|meta.llama3-70b-instruct-v1:0|text-to-text|🟡|🟡|❌|
|Meta|Llama 3 Chat 8B Instruct|meta.llama3-8b-instruct-v1:0|text-to-text|🟡|🟡|❌|
-|Meta|Llama 2 Chat 70B|meta.llama2-70b-chat-v1|text-to-text|🟡|🟡|❌|
-|Meta|Llama 2 Chat 13B|meta.llama2-13b-chat-v1|text-to-text|🟡|🟡|❌|
|Stability AI|SDXL 1.0|stability.stable-diffusion-xl-v1|text-to-image|❌|🟡|❌|
|Stability AI|SD3 Large 1.0|stability.sd3-large-v1:0|text-to-image / image-to-image|❌|🟡|❌|
|Stability AI|Stable Image Ultra 1.0|stability.stable-image-ultra-v1:0|text-to-image|❌|🟡|❌|
diff --git a/aidial_adapter_bedrock/bedrock.py b/aidial_adapter_bedrock/bedrock.py
index 21fe0fa3..dd919a3a 100644
--- a/aidial_adapter_bedrock/bedrock.py
+++ b/aidial_adapter_bedrock/bedrock.py
@@ -1,7 +1,7 @@
import json
from abc import ABC
from logging import DEBUG
-from typing import Any, AsyncIterator, Mapping, Optional, Tuple
+from typing import Any, AsyncIterator, Mapping, Optional, Tuple, Unpack
import boto3
from botocore.eventstream import EventStream
@@ -10,6 +10,7 @@
from aidial_adapter_bedrock.aws_client_config import AWSClientConfig
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
+from aidial_adapter_bedrock.llm.converse.types import ConverseRequest
from aidial_adapter_bedrock.utils.concurrency import (
make_async,
to_async_iterator,
@@ -36,6 +37,23 @@ async def acreate(cls, aws_client_config: AWSClientConfig) -> "Bedrock":
)
return cls(client)
+ async def aconverse_non_streaming(
+ self, model: str, **params: Unpack[ConverseRequest]
+ ):
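+        # Non-streaming call to the Bedrock Converse API; the blocking boto3
+        # converse() call is wrapped in make_async to keep the event loop free.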
+ response = await make_async(
+ lambda: self.client.converse(modelId=model, **params)
+ )
+ return response
+
+ async def aconverse_streaming(
+ self, model: str, **params: Unpack[ConverseRequest]
+ ):
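+        # Streaming call to the Bedrock Converse API; converse_stream() returns
+        # an EventStream which is exposed to the caller as an async iterator.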
+ response = await make_async(
+ lambda: self.client.converse_stream(modelId=model, **params)
+ )
+
+ return to_async_iterator(iter(response["stream"]))
+
def _create_invoke_params(self, model: str, body: dict) -> dict:
return {
"modelId": model,
diff --git a/aidial_adapter_bedrock/deployments.py b/aidial_adapter_bedrock/deployments.py
index e0e96a87..8758708a 100644
--- a/aidial_adapter_bedrock/deployments.py
+++ b/aidial_adapter_bedrock/deployments.py
@@ -43,13 +43,15 @@ class ChatCompletionDeployment(str, Enum):
STABILITY_STABLE_DIFFUSION_3_LARGE_V1 = "stability.sd3-large-v1:0"
STABILITY_STABLE_IMAGE_ULTRA_V1 = "stability.stable-image-ultra-v1:0"
- META_LLAMA2_13B_CHAT_V1 = "meta.llama2-13b-chat-v1"
- META_LLAMA2_70B_CHAT_V1 = "meta.llama2-70b-chat-v1"
META_LLAMA3_8B_INSTRUCT_V1 = "meta.llama3-8b-instruct-v1:0"
META_LLAMA3_70B_INSTRUCT_V1 = "meta.llama3-70b-instruct-v1:0"
- META_LLAMA3_1_405B_INSTRUCT_V1 = "meta.llama3-1-405b-instruct-v1:0"
- META_LLAMA3_1_70B_INSTRUCT_V1 = "meta.llama3-1-70b-instruct-v1:0"
META_LLAMA3_1_8B_INSTRUCT_V1 = "meta.llama3-1-8b-instruct-v1:0"
+ META_LLAMA3_1_70B_INSTRUCT_V1 = "meta.llama3-1-70b-instruct-v1:0"
+ META_LLAMA3_1_405B_INSTRUCT_V1 = "meta.llama3-1-405b-instruct-v1:0"
+ META_LLAMA3_2_1B_INSTRUCT_V1 = "us.meta.llama3-2-1b-instruct-v1:0"
+ META_LLAMA3_2_3B_INSTRUCT_V1 = "us.meta.llama3-2-3b-instruct-v1:0"
+ META_LLAMA3_2_11B_INSTRUCT_V1 = "us.meta.llama3-2-11b-instruct-v1:0"
+ META_LLAMA3_2_90B_INSTRUCT_V1 = "us.meta.llama3-2-90b-instruct-v1:0"
COHERE_COMMAND_TEXT_V14 = "cohere.command-text-v14"
COHERE_COMMAND_LIGHT_TEXT_V14 = "cohere.command-light-text-v14"
diff --git a/aidial_adapter_bedrock/llm/converse/__init__.py b/aidial_adapter_bedrock/llm/converse/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/aidial_adapter_bedrock/llm/converse/adapter.py b/aidial_adapter_bedrock/llm/converse/adapter.py
new file mode 100644
index 00000000..3da26bb9
--- /dev/null
+++ b/aidial_adapter_bedrock/llm/converse/adapter.py
@@ -0,0 +1,180 @@
+from typing import Any, Awaitable, Callable, List, Tuple
+
+from aidial_sdk.chat_completion import Message as DialMessage
+
+from aidial_adapter_bedrock.bedrock import Bedrock
+from aidial_adapter_bedrock.dial_api.request import ModelParameters
+from aidial_adapter_bedrock.dial_api.storage import FileStorage
+from aidial_adapter_bedrock.llm.chat_model import (
+ ChatCompletionAdapter,
+ keep_last,
+ turn_based_partitioner,
+)
+from aidial_adapter_bedrock.llm.consumer import Consumer
+from aidial_adapter_bedrock.llm.converse.input import (
+ extract_converse_system_prompt,
+ to_converse_messages,
+ to_converse_tools,
+)
+from aidial_adapter_bedrock.llm.converse.output import (
+ process_non_streaming,
+ process_streaming,
+)
+from aidial_adapter_bedrock.llm.converse.types import (
+ ConverseDeployment,
+ ConverseMessage,
+ ConverseRequestWrapper,
+ ConverseTools,
+ InferenceConfig,
+)
+from aidial_adapter_bedrock.llm.errors import ValidationError
+from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string
+from aidial_adapter_bedrock.llm.truncate_prompt import (
+ DiscardedMessages,
+ truncate_prompt,
+)
+from aidial_adapter_bedrock.utils.json import remove_nones
+from aidial_adapter_bedrock.utils.list import omit_by_indices
+from aidial_adapter_bedrock.utils.list_projection import ListProjection
+
+ConverseMessages = List[Tuple[ConverseMessage, Any]]
+
+
+class ConverseAdapter(ChatCompletionAdapter):
+ deployment: str
+ bedrock: Bedrock
+ storage: FileStorage | None
+
+ tokenize_text: Callable[[str], int] = default_tokenize_string
+ input_tokenizer_factory: Callable[
+ [ConverseDeployment, ConverseRequestWrapper],
+ Callable[[ConverseMessages], Awaitable[int]],
+ ]
+ support_tools: bool
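+    # Splits the message list into groups that are kept or discarded together
+    # during prompt truncation.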
+ partitioner: Callable[[ConverseMessages], List[int]] = (
+ turn_based_partitioner
+ )
+
+ async def _discard_messages(
+ self, params: ConverseRequestWrapper, max_prompt_tokens: int | None
+ ) -> Tuple[DiscardedMessages | None, ConverseRequestWrapper]:
+ if max_prompt_tokens is None:
+ return None, params
+
+ discarded_messages, messages = await truncate_prompt(
+ messages=params.messages.list,
+ tokenizer=self.input_tokenizer_factory(self.deployment, params),
+ keep_message=keep_last,
+ partitioner=self.partitioner,
+ model_limit=None,
+ user_limit=max_prompt_tokens,
+ )
+
+ return list(
+ params.messages.to_original_indices(discarded_messages)
+ ), ConverseRequestWrapper(
+ messages=ListProjection(
+ omit_by_indices(messages, discarded_messages)
+ ),
+ system=params.system,
+ inferenceConfig=params.inferenceConfig,
+ toolConfig=params.toolConfig,
+ )
+
+ async def count_prompt_tokens(
+ self, params: ModelParameters, messages: List[DialMessage]
+ ) -> int:
+ converse_params = await self.construct_converse_params(messages, params)
+ return await self.input_tokenizer_factory(
+ self.deployment, converse_params
+ )(converse_params.messages.list)
+
+ async def count_completion_tokens(self, string: str) -> int:
+ return self.tokenize_text(string)
+
+ async def compute_discarded_messages(
+ self, params: ModelParameters, messages: List[DialMessage]
+ ) -> DiscardedMessages | None:
+ converse_params = await self.construct_converse_params(messages, params)
+ discarded_messages, _ = await self._discard_messages(
+ converse_params, params.max_prompt_tokens
+ )
+ return discarded_messages
+
+ def get_tool_config(self, params: ModelParameters) -> ConverseTools | None:
+ if params.tool_config and not self.support_tools:
+ raise ValidationError("Tools are not supported")
+ return (
+ to_converse_tools(params.tool_config)
+ if params.tool_config
+ else None
+ )
+
+ async def construct_converse_params(
+ self,
+ messages: List[DialMessage],
+ params: ModelParameters,
+ ) -> ConverseRequestWrapper:
+ system_prompt_extraction = extract_converse_system_prompt(messages)
+ converse_messages = await to_converse_messages(
+ system_prompt_extraction.non_system_messages,
+ self.storage,
+ start_offset=system_prompt_extraction.system_message_count,
+ )
+ system_message = system_prompt_extraction.system_prompt
+ if not converse_messages.list:
+ raise ValidationError("List of messages must not be empty")
+
+ return ConverseRequestWrapper(
+ system=[system_message] if system_message else None,
+ messages=converse_messages,
+ inferenceConfig=InferenceConfig(
+ **remove_nones(
+ {
+ "temperature": params.temperature,
+ "topP": params.top_p,
+ "maxTokens": params.max_tokens,
+ "stopSequences": params.stop,
+ }
+ )
+ ),
+ toolConfig=self.get_tool_config(params),
+ )
+
+ def is_stream(self, params: ModelParameters) -> bool:
+ return params.stream
+
+ async def chat(
+ self,
+ consumer: Consumer,
+ params: ModelParameters,
+ messages: List[DialMessage],
+ ) -> None:
+
+ converse_params = await self.construct_converse_params(messages, params)
+ discarded_messages, converse_params = await self._discard_messages(
+ converse_params, params.max_prompt_tokens
+ )
+ if not converse_params.messages.raw_list:
+ raise ValidationError("No messages left after truncation")
+
+ consumer.set_discarded_messages(discarded_messages)
+
+ if self.is_stream(params):
+ await process_streaming(
+ params=params,
+ stream=(
+ await self.bedrock.aconverse_streaming(
+ self.deployment, **converse_params.to_request()
+ )
+ ),
+ consumer=consumer,
+ )
+ else:
+ process_non_streaming(
+ params=params,
+ response=await self.bedrock.aconverse_non_streaming(
+ self.deployment, **converse_params.to_request()
+ ),
+ consumer=consumer,
+ )
diff --git a/aidial_adapter_bedrock/llm/converse/constants.py b/aidial_adapter_bedrock/llm/converse/constants.py
new file mode 100644
index 00000000..562c92b6
--- /dev/null
+++ b/aidial_adapter_bedrock/llm/converse/constants.py
@@ -0,0 +1,12 @@
+from aidial_sdk.chat_completion import FinishReason as DialFinishReason
+
+from aidial_adapter_bedrock.llm.converse.types import ConverseStopReason
+
+CONVERSE_TO_DIAL_FINISH_REASON = {
+ ConverseStopReason.END_TURN: DialFinishReason.STOP,
+ ConverseStopReason.TOOL_USE: DialFinishReason.TOOL_CALLS,
+ ConverseStopReason.MAX_TOKENS: DialFinishReason.LENGTH,
+ ConverseStopReason.STOP_SEQUENCE: DialFinishReason.STOP,
+ ConverseStopReason.GUARDRAIL_INTERVENED: DialFinishReason.CONTENT_FILTER,
+ ConverseStopReason.CONTENT_FILTERED: DialFinishReason.CONTENT_FILTER,
+}
diff --git a/aidial_adapter_bedrock/llm/converse/input.py b/aidial_adapter_bedrock/llm/converse/input.py
new file mode 100644
index 00000000..1df45a2d
--- /dev/null
+++ b/aidial_adapter_bedrock/llm/converse/input.py
@@ -0,0 +1,325 @@
+import json
+from dataclasses import dataclass
+from typing import List, Set, Tuple, assert_never
+
+from aidial_sdk.chat_completion import FunctionCall as DialFunctionCall
+from aidial_sdk.chat_completion import Message as DialMessage
+from aidial_sdk.chat_completion import (
+ MessageContentImagePart,
+ MessageContentTextPart,
+)
+from aidial_sdk.chat_completion import Role as DialRole
+from aidial_sdk.chat_completion import ToolCall as DialToolCall
+from aidial_sdk.exceptions import RuntimeServerError
+
+from aidial_adapter_bedrock.dial_api.request import ToolsConfig
+from aidial_adapter_bedrock.dial_api.resource import (
+ AttachmentResource,
+ URLResource,
+)
+from aidial_adapter_bedrock.dial_api.storage import FileStorage
+from aidial_adapter_bedrock.llm.converse.types import (
+ ConverseContentPart,
+ ConverseMessage,
+ ConverseRole,
+ ConverseTextPart,
+ ConverseToolResultPart,
+ ConverseTools,
+ ConverseToolSpec,
+ ConverseToolUsePart,
+)
+from aidial_adapter_bedrock.llm.errors import ValidationError
+from aidial_adapter_bedrock.utils.list import group_by
+from aidial_adapter_bedrock.utils.list_projection import ListProjection
+
+
+def to_converse_role(role: DialRole) -> ConverseRole:
+ """
+    The Converse API accepts only 'user' and 'assistant' roles.
+ """
+ match role:
+ case DialRole.USER | DialRole.TOOL | DialRole.FUNCTION:
+ return ConverseRole.USER
+ case DialRole.ASSISTANT:
+ return ConverseRole.ASSISTANT
+ case DialRole.SYSTEM:
+ raise ValidationError("System messages are not allowed")
+ case _:
+ assert_never(role)
+
+
+def to_converse_tools(tools_config: ToolsConfig) -> ConverseTools:
+ tools: list[ConverseToolSpec] = []
+ for function in tools_config.functions:
+ tools.append(
+ {
+ "toolSpec": {
+ "name": function.name,
+ "description": function.description or "",
+ "inputSchema": {
+ "json": function.parameters
+ or {"type": "object", "properties": {}}
+ },
+ }
+ }
+ )
+
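+    # toolChoice "any" forces the model to call one of the provided tools;
+    # "auto" lets the model decide whether to call a tool at all.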
+ return {
+ "tools": tools,
+ "toolChoice": ({"any": {}} if tools_config.required else {"auto": {}}),
+ }
+
+
+def function_call_to_content_part(
+ dial_call: DialFunctionCall,
+) -> ConverseToolUsePart:
+ return {
+ "toolUse": {
+ "toolUseId": dial_call.name,
+ "name": dial_call.name,
+ "input": json.loads(dial_call.arguments),
+ }
+ }
+
+
+def tool_call_to_content_part(
+ dial_call: DialToolCall,
+) -> ConverseToolUsePart:
+ return {
+ "toolUse": {
+ "toolUseId": dial_call.id,
+ "name": dial_call.function.name,
+ "input": json.loads(dial_call.function.arguments),
+ }
+ }
+
+
+def function_result_to_content_part(
+ message: DialMessage,
+) -> ConverseToolResultPart:
+ if message.role != DialRole.FUNCTION:
+ raise RuntimeServerError(
+ "Function result message is expected to have function role"
+ )
+ if not message.name or not isinstance(message.content, str):
+ raise RuntimeServerError(
+ "Function result message is expected to have function name and plain text content"
+ )
+
+ return {
+ "toolResult": {
+ "toolUseId": message.name,
+ "content": [{"text": message.content}],
+ "status": "success",
+ }
+ }
+
+
+def tool_result_to_content_part(
+ message: DialMessage,
+) -> ConverseToolResultPart:
+ if message.role != DialRole.TOOL:
+ raise RuntimeServerError(
+ "Tool result message is expected to have tool role"
+ )
+ if not message.tool_call_id or not isinstance(message.content, str):
+ raise RuntimeServerError(
+ "Tool result message is expected to have tool call id and plain text content"
+ )
+
+ try:
+ json_content = json.loads(message.content)
+ return {
+ "toolResult": {
+ "toolUseId": message.tool_call_id,
+ "content": [{"json": json_content}],
+ "status": "success",
+ }
+ }
+ except json.JSONDecodeError:
+ return {
+ "toolResult": {
+ "toolUseId": message.tool_call_id,
+ "content": [{"text": message.content}],
+ "status": "success",
+ }
+ }
+
+
+def to_converse_image_type(type: str) -> str:
+ if type == "image/png":
+ return "png"
+ if type == "image/jpeg":
+ return "jpeg"
+ raise RuntimeServerError(f"Unsupported image type: {type}")
+
+
+async def _get_converse_message_content(
+ message: DialMessage,
+ storage: FileStorage | None,
+ supported_image_types: list[str] | None = None,
+) -> List[ConverseContentPart]:
+
+ if message.role == DialRole.FUNCTION:
+ return [function_result_to_content_part(message)]
+ elif message.role == DialRole.TOOL:
+ return [tool_result_to_content_part(message)]
+
+ content = []
+ match message.content:
+ case str():
+ content.append({"text": message.content})
+ case list():
+ for part in message.content:
+ match part:
+ case MessageContentTextPart():
+ content.append({"text": part.text})
+ case MessageContentImagePart():
+ resource = await URLResource(
+ url=part.image_url.url,
+ supported_types=supported_image_types,
+ ).download(storage)
+ content.append(
+ {
+ "image": {
+ "format": to_converse_image_type(
+ resource.type
+ ),
+ "source": {
+ "bytes": resource.data,
+ },
+ }
+ }
+ )
+ case None:
+ pass
+ case _:
+ assert_never(message.content)
+
+ if message.custom_content and message.custom_content.attachments:
+ for attachment in message.custom_content.attachments:
+ resource = await AttachmentResource(attachment=attachment).download(
+ storage
+ )
+ content.append(
+ {
+ "image": {
+ "format": to_converse_image_type(resource.type),
+ "source": {"bytes": resource.data},
+ }
+ }
+ )
+ if message.function_call and message.tool_calls:
+ raise ValidationError(
+ "You cannot use both function call and tool calls in the same message"
+ )
+ elif message.function_call:
+ content.append(function_call_to_content_part(message.function_call))
+ elif message.tool_calls:
+ content.extend(
+ [
+ tool_call_to_content_part(tool_call)
+ for tool_call in message.tool_calls
+ ]
+ )
+
+ return content
+
+
+async def to_converse_message(
+ message: DialMessage,
+ storage: FileStorage | None,
+ supported_image_types: list[str] | None = None,
+) -> ConverseMessage:
+
+ return {
+ "role": to_converse_role(message.role),
+ "content": await _get_converse_message_content(
+ message, storage, supported_image_types
+ ),
+ }
+
+
+@dataclass
+class ExtractSystemPromptResult:
+ system_prompt: ConverseTextPart | None
+ system_message_count: int
+ non_system_messages: List[DialMessage]
+
+
+def extract_converse_system_prompt(
+ messages: List[DialMessage],
+) -> ExtractSystemPromptResult:
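+    # Collect the leading system messages into a single Converse system prompt;
+    # a system message that appears after a non-system message is rejected.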
+ system_msgs = []
+ found_non_system = False
+ system_messages_count = 0
+ non_system_messages = []
+
+ for msg in messages:
+ if msg.role == DialRole.SYSTEM:
+ if found_non_system:
+ raise ValidationError(
+ "A system message can only follow another system message"
+ )
+ system_messages_count += 1
+ match msg.content:
+ case str():
+ system_msgs.append(msg.content)
+ case list():
+ for part in msg.content:
+ match part:
+ case MessageContentTextPart():
+ system_msgs.append(part.text)
+ case MessageContentImagePart():
+ raise ValidationError(
+ "System messages cannot contain images"
+ )
+ case None:
+ pass
+ case _:
+ assert_never(msg.content)
+ else:
+ found_non_system = True
+ non_system_messages.append(msg)
+ combined = "\n\n".join(msg for msg in system_msgs if msg)
+ return ExtractSystemPromptResult(
+ system_prompt=ConverseTextPart(text=combined) if combined else None,
+ system_message_count=system_messages_count,
+ non_system_messages=non_system_messages,
+ )
+
+
+async def to_converse_messages(
+ messages: List[DialMessage],
+ storage: FileStorage | None,
+ # Offset for system messages at the beginning
+ start_offset: int = 0,
+) -> ListProjection[ConverseMessage]:
+ def _merge(
+ a: Tuple[ConverseMessage, Set[int]],
+ b: Tuple[ConverseMessage, Set[int]],
+ ) -> Tuple[ConverseMessage, Set[int]]:
+ (msg1, set1), (msg2, set2) = a, b
+
+ content1 = msg1["content"]
+ content2 = msg2["content"]
+
+ return {
+ "role": msg1["role"],
+ "content": content1 + content2,
+ }, set1 | set2
+
+ converted = [
+ (await to_converse_message(msg, storage), set([idx]))
+ for idx, msg in enumerate(messages, start=start_offset)
+ ]
+
+    # Merge consecutive messages sharing the same role so that user and assistant messages alternate.
+ return ListProjection(
+ group_by(
+ lst=converted,
+ key=lambda msg: msg[0]["role"],
+ init=lambda msg: msg,
+ merge=_merge,
+ )
+ )
diff --git a/aidial_adapter_bedrock/llm/converse/output.py b/aidial_adapter_bedrock/llm/converse/output.py
new file mode 100644
index 00000000..83d07813
--- /dev/null
+++ b/aidial_adapter_bedrock/llm/converse/output.py
@@ -0,0 +1,147 @@
+import json
+from typing import Any, AsyncIterator, Dict, assert_never
+
+from aidial_sdk.chat_completion import FinishReason as DialFinishReason
+from aidial_sdk.chat_completion import FunctionCall as DialFunctionCall
+from aidial_sdk.chat_completion import ToolCall as DialToolCall
+from aidial_sdk.exceptions import RuntimeServerError
+
+from aidial_adapter_bedrock.dial_api.request import ModelParameters
+from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
+from aidial_adapter_bedrock.llm.consumer import Consumer
+from aidial_adapter_bedrock.llm.converse.constants import (
+ CONVERSE_TO_DIAL_FINISH_REASON,
+)
+from aidial_adapter_bedrock.llm.converse.types import ConverseStopReason
+from aidial_adapter_bedrock.llm.tools.tools_config import ToolsMode
+
+
+def to_dial_finish_reason(
+ converse_stop_reason: ConverseStopReason,
+) -> DialFinishReason:
+    if converse_stop_reason not in CONVERSE_TO_DIAL_FINISH_REASON:
+ raise RuntimeServerError(
+ f"Unsupported converse stop reason: {converse_stop_reason}"
+ )
+ return CONVERSE_TO_DIAL_FINISH_REASON[converse_stop_reason]
+
+
+async def process_streaming(
+ params: ModelParameters,
+ stream: AsyncIterator[Any],
+ consumer: Consumer,
+) -> None:
+ current_tool_use = None
+
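+    # The Converse stream yields event dicts: contentBlockStart opens a tool-use
+    # block, contentBlockDelta carries text or partial tool input, contentBlockStop
+    # finalizes the accumulated tool call, and messageStop carries the stop reason.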
+ async for event in stream:
+ if (content_block_start := event.get("contentBlockStart")) and (
+ tool_use := content_block_start.get("start", {}).get("toolUse")
+ ):
+ if current_tool_use is not None:
+ raise ValueError("Tool use already started")
+ current_tool_use = {"input": ""} | tool_use
+
+ elif content_block := event.get("contentBlockDelta"):
+ delta = content_block.get("delta", {})
+
+ if message := delta.get("text"):
+ consumer.append_content(message)
+
+ if "toolUse" in delta:
+ if current_tool_use is None:
+ raise ValueError("Received tool delta before start block")
+ else:
+ current_tool_use["input"] += delta["toolUse"].get(
+ "input", ""
+ )
+
+ elif event.get("contentBlockStop"):
+ if current_tool_use:
+
+ match params.tools_mode:
+ case ToolsMode.TOOLS:
+ consumer.create_function_tool_call(
+ tool_call=DialToolCall(
+ type="function",
+ id=current_tool_use["toolUseId"],
+ index=None,
+ function=DialFunctionCall(
+ name=current_tool_use["name"],
+ arguments=current_tool_use["input"],
+ ),
+ )
+ )
+ case ToolsMode.FUNCTIONS:
+                        # Ignore all but the first function call in a single response
+ if not consumer.has_function_call:
+ consumer.create_function_call(
+ function_call=DialFunctionCall(
+ name=current_tool_use["name"],
+ arguments=current_tool_use["input"],
+ )
+ )
+ case None:
+ raise RuntimeError(
+ "Tool use received without tools mode"
+ )
+ case _:
+ assert_never(params.tools_mode)
+ current_tool_use = None
+
+ elif (message_stop := event.get("messageStop")) and (
+ stop_reason := message_stop.get("stopReason")
+ ):
+ consumer.close_content(to_dial_finish_reason(stop_reason))
+
+
+def process_non_streaming(
+ params: ModelParameters,
+ response: Dict[str, Any],
+ consumer: Consumer,
+) -> None:
+ message = response["output"]["message"]
+ for content_block in message.get("content", []):
+ if "text" in content_block:
+ consumer.append_content(content_block["text"])
+ if "toolUse" in content_block:
+ match params.tools_mode:
+ case ToolsMode.TOOLS:
+ consumer.create_function_tool_call(
+ tool_call=DialToolCall(
+ type="function",
+ id=content_block["toolUse"]["toolUseId"],
+ index=None,
+ function=DialFunctionCall(
+ name=content_block["toolUse"]["name"],
+ arguments=json.dumps(
+ content_block["toolUse"]["input"]
+ ),
+ ),
+ )
+ )
+ case ToolsMode.FUNCTIONS:
+                    # Ignore all but the first function call in a single response
+ if not consumer.has_function_call:
+ consumer.create_function_call(
+ function_call=DialFunctionCall(
+ name=content_block["toolUse"]["name"],
+ arguments=json.dumps(
+ content_block["toolUse"]["input"]
+ ),
+ )
+ )
+ case None:
+ raise RuntimeError("Tool use received without tools mode")
+ case _:
+ assert_never(params.tools_mode)
+
+ if usage := response.get("usage"):
+ consumer.add_usage(
+ TokenUsage(
+ prompt_tokens=usage.get("inputTokens", 0),
+ completion_tokens=usage.get("outputTokens", 0),
+ )
+ )
+
+ if stop_reason := response.get("stopReason"):
+ consumer.close_content(to_dial_finish_reason(stop_reason))
diff --git a/aidial_adapter_bedrock/llm/converse/types.py b/aidial_adapter_bedrock/llm/converse/types.py
new file mode 100644
index 00000000..54b8e153
--- /dev/null
+++ b/aidial_adapter_bedrock/llm/converse/types.py
@@ -0,0 +1,137 @@
+"""
+Types for Converse API:
+https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html
+"""
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Literal, Required, TypedDict, Union
+
+from aidial_adapter_bedrock.utils.json import remove_nones
+from aidial_adapter_bedrock.utils.list_projection import ListProjection
+
+
+class ConverseRole(str, Enum):
+ USER = "user"
+ ASSISTANT = "assistant"
+
+
+class ConverseTextPart(TypedDict):
+ text: str
+
+
+class ConverseJsonPart(TypedDict):
+ json: dict
+
+
+class ConverseImageSource(TypedDict):
+ bytes: bytes
+
+
+class ConverseImagePartConfig(TypedDict):
+ format: Literal["png", "jpeg", "gif", "webp"]
+ source: ConverseImageSource
+
+
+class ConverseImagePart(TypedDict):
+ image: ConverseImagePartConfig
+
+
+class ConverseToolUseConfig(TypedDict):
+ toolUseId: str
+ name: str
+    # Any JSON-serializable value: dict, list, int, float, str, bool or None
+ input: Any
+
+
+class ConverseToolUsePart(TypedDict):
+ toolUse: ConverseToolUseConfig
+
+
+class ConverseToolResultConfig(TypedDict):
+ toolUseId: str
+ content: list[ConverseTextPart | ConverseJsonPart]
+ status: str
+
+
+class ConverseToolResultPart(TypedDict):
+ toolResult: ConverseToolResultConfig
+
+
+ConverseContentPart = Union[
+ ConverseTextPart,
+ ConverseJsonPart,
+ ConverseImagePart,
+ ConverseToolUsePart,
+ ConverseToolResultPart,
+]
+
+
+class ConverseToolConfig(TypedDict):
+ name: str
+ description: str
+ inputSchema: dict
+
+
+class ConverseToolSpec(TypedDict):
+ toolSpec: ConverseToolConfig
+
+
+class ConverseTools(TypedDict):
+ tools: list[ConverseToolSpec]
+ toolChoice: dict
+
+
+class ConverseToolUse(TypedDict):
+ toolUse: ConverseToolUseConfig
+
+
+class ConverseStopReason(str, Enum):
+ END_TURN = "end_turn"
+ TOOL_USE = "tool_use"
+ MAX_TOKENS = "max_tokens"
+ STOP_SEQUENCE = "stop_sequence"
+ GUARDRAIL_INTERVENED = "guardrail_intervened"
+ CONTENT_FILTERED = "content_filtered"
+
+
+class ConverseMessage(TypedDict):
+ role: ConverseRole
+ content: list[ConverseContentPart]
+
+
+class InferenceConfig(TypedDict, total=False):
+ temperature: float
+ topP: float
+ maxTokens: int
+ stopSequences: list[str]
+
+
+class ConverseRequest(TypedDict, total=False):
+ messages: Required[list[ConverseMessage]]
+ system: list[ConverseTextPart]
+ inferenceConfig: InferenceConfig
+ toolConfig: ConverseTools
+
+
+@dataclass
+class ConverseRequestWrapper:
+ messages: ListProjection[ConverseMessage]
+ system: list[ConverseTextPart] | None = None
+ inferenceConfig: InferenceConfig | None = None
+ toolConfig: ConverseTools | None = None
+
+ def to_request(self) -> ConverseRequest:
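+        # Only the optional fields that are actually set are forwarded;
+        # remove_nones drops the None-valued entries before the boto3 call.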
+ return ConverseRequest(
+ messages=self.messages.raw_list,
+ **remove_nones(
+ {
+ "inferenceConfig": self.inferenceConfig,
+ "toolConfig": self.toolConfig,
+ "system": self.system,
+ }
+ ),
+ )
+
+
+ConverseDeployment = str
diff --git a/aidial_adapter_bedrock/llm/model/adapter.py b/aidial_adapter_bedrock/llm/model/adapter.py
index dcc39b7b..9ee6c50c 100644
--- a/aidial_adapter_bedrock/llm/model/adapter.py
+++ b/aidial_adapter_bedrock/llm/model/adapter.py
@@ -6,6 +6,7 @@
ChatCompletionDeployment,
EmbeddingsDeployment,
)
+from aidial_adapter_bedrock.dial_api.storage import create_file_storage
from aidial_adapter_bedrock.embedding.amazon.titan_image import (
AmazonTitanImageEmbeddings,
)
@@ -19,6 +20,7 @@
EmbeddingsAdapter,
)
from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter
+from aidial_adapter_bedrock.llm.converse.adapter import ConverseAdapter
from aidial_adapter_bedrock.llm.model.ai21 import AI21Adapter
from aidial_adapter_bedrock.llm.model.amazon import AmazonAdapter
from aidial_adapter_bedrock.llm.model.claude.v1_v2.adapter import (
@@ -28,9 +30,12 @@
Adapter as Claude_V3,
)
from aidial_adapter_bedrock.llm.model.cohere import CohereAdapter
-from aidial_adapter_bedrock.llm.model.llama.v2 import llama2_config
-from aidial_adapter_bedrock.llm.model.llama.v3 import llama3_config
-from aidial_adapter_bedrock.llm.model.meta import MetaAdapter
+from aidial_adapter_bedrock.llm.model.llama.v3 import (
+ ConverseAdapterWithStreamingEmulation,
+)
+from aidial_adapter_bedrock.llm.model.llama.v3 import (
+ input_tokenizer_factory as llama_tokenizer_factory,
+)
from aidial_adapter_bedrock.llm.model.stability.v1 import StabilityV1Adapter
from aidial_adapter_bedrock.llm.model.stability.v2 import StabilityV2Adapter
@@ -106,23 +111,31 @@ async def get_bedrock_adapter(
await Bedrock.acreate(aws_client_config), model
)
case (
- ChatCompletionDeployment.META_LLAMA2_13B_CHAT_V1
- | ChatCompletionDeployment.META_LLAMA2_70B_CHAT_V1
+ ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1
+ | ChatCompletionDeployment.META_LLAMA3_70B_INSTRUCT_V1
+ | ChatCompletionDeployment.META_LLAMA3_1_8B_INSTRUCT_V1
+ | ChatCompletionDeployment.META_LLAMA3_2_1B_INSTRUCT_V1
+ | ChatCompletionDeployment.META_LLAMA3_2_3B_INSTRUCT_V1
):
- return MetaAdapter.create(
- await Bedrock.acreate(aws_client_config), model, llama2_config
+ return ConverseAdapter(
+ deployment=model,
+ bedrock=await Bedrock.acreate(aws_client_config),
+ storage=create_file_storage(api_key),
+ input_tokenizer_factory=llama_tokenizer_factory,
+ support_tools=False,
)
case (
- ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1
- | ChatCompletionDeployment.META_LLAMA3_70B_INSTRUCT_V1
+ ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1
| ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1
- | ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1
- | ChatCompletionDeployment.META_LLAMA3_1_8B_INSTRUCT_V1
+ | ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1
+ | ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1
):
- return MetaAdapter.create(
- await Bedrock.acreate(aws_client_config),
- model,
- llama3_config,
+ return ConverseAdapterWithStreamingEmulation(
+ deployment=model,
+ bedrock=await Bedrock.acreate(aws_client_config),
+ storage=create_file_storage(api_key),
+ input_tokenizer_factory=llama_tokenizer_factory,
+ support_tools=True,
)
case (
ChatCompletionDeployment.COHERE_COMMAND_TEXT_V14
diff --git a/aidial_adapter_bedrock/llm/model/llama/conf.py b/aidial_adapter_bedrock/llm/model/llama/conf.py
deleted file mode 100644
index 208792e4..00000000
--- a/aidial_adapter_bedrock/llm/model/llama/conf.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from typing import Callable, List
-
-from pydantic import BaseModel
-
-from aidial_adapter_bedrock.llm.chat_emulator import ChatEmulator
-from aidial_adapter_bedrock.llm.message import BaseMessage
-
-
-class LlamaConf(BaseModel):
- chat_partitioner: Callable[[List[BaseMessage]], List[int]]
- chat_emulator: ChatEmulator
diff --git a/aidial_adapter_bedrock/llm/model/llama/v2.py b/aidial_adapter_bedrock/llm/model/llama/v2.py
deleted file mode 100644
index 66f5e55c..00000000
--- a/aidial_adapter_bedrock/llm/model/llama/v2.py
+++ /dev/null
@@ -1,142 +0,0 @@
-"""
-Turning a chat into a prompt for the Llama2 model.
-
-The reference for the algo is [this code snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362) in the original repository.
-
-See also the [tokenizer](https://github.com/huggingface/transformers/blob/c99f25476312521d4425335f970b198da42f832d/src/transformers/models/llama/tokenization_llama.py#L415) in the transformers package.
-"""
-
-from typing import List, Optional, Tuple
-
-from pydantic import BaseModel
-
-from aidial_adapter_bedrock.llm.chat_emulator import ChatEmulator
-from aidial_adapter_bedrock.llm.errors import ValidationError
-from aidial_adapter_bedrock.llm.message import (
- AIRegularMessage,
- BaseMessage,
- HumanRegularMessage,
- SystemMessage,
-)
-from aidial_adapter_bedrock.llm.model.llama.conf import LlamaConf
-
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-BOS = "<s>"
-EOS = "</s>"
-
-
-class Dialogue(BaseModel):
- """Valid dialog structure for LLAMA2 model:
- 1. optional system message,
- 2. alternating user/assistant messages,
- 3. last user query"""
-
- system: Optional[str]
- turns: List[Tuple[str, str]]
- human: str
-
- def prepend_to_first_human_message(self, text: str) -> None:
- if self.turns:
- human, ai = self.turns[0]
- self.turns[0] = text + human, ai
- else:
- self.human = text + self.human
-
-
-def validate_chat(messages: List[BaseMessage]) -> Dialogue:
- system: Optional[str] = None
- if messages and isinstance(messages[0], SystemMessage):
- system = messages[0].text_content
- if system.strip() == "":
- system = None
- messages = messages[1:]
-
- human = messages[::2]
- ai = messages[1::2]
-
- is_valid_alternation = all(
- isinstance(msg, HumanRegularMessage) for msg in human
- ) and all(isinstance(msg, AIRegularMessage) for msg in ai)
-
- if not is_valid_alternation:
- raise ValidationError(
- "The model only supports initial optional system message and"
- " follow-up alternating human/assistant messages"
- )
-
- turns = [
- (human.text_content, assistant.text_content)
- for human, assistant in zip(human, ai)
- ]
-
- if messages and isinstance(messages[-1], HumanRegularMessage):
- last_query = messages[-1]
- else:
- raise ValidationError("The last message must be from user")
-
- return Dialogue(
- system=system,
- turns=turns,
- human=last_query.text_content,
- )
-
-
-def format_sequence(text: str, bos: bool, eos: bool) -> str:
- if bos:
- text = BOS + text
- if eos:
- text = text + EOS
- return text
-
-
-def create_chat_prompt(dialogue: Dialogue) -> str:
- system = dialogue.system
- if system is not None:
- dialogue.prepend_to_first_human_message(B_SYS + system + E_SYS)
-
- ret: List[str] = [
- format_sequence(
- f"{B_INST} {human.strip()} {E_INST} {ai.strip()} ",
- bos=True,
- eos=True,
- )
- for human, ai in dialogue.turns
- ]
-
- ret.append(
- format_sequence(
- f"{B_INST} {dialogue.human.strip()} {E_INST}",
- bos=True,
- eos=False,
- )
- )
-
- return "".join(ret)
-
-
-class LlamaChatEmulator(ChatEmulator):
- def display(self, messages: List[BaseMessage]) -> Tuple[str, List[str]]:
- dialogue = validate_chat(messages)
- return create_chat_prompt(dialogue), []
-
- def get_ai_cue(self) -> Optional[str]:
- return None
-
-
-def llama2_chat_partitioner(messages: List[BaseMessage]) -> List[int]:
- dialogue = validate_chat(messages)
-
- ret: List[int] = []
- if dialogue.system is not None:
- ret.append(1)
- ret.extend([2] * len(dialogue.turns))
- ret.append(1)
-
- return ret
-
-
-llama2_config = LlamaConf(
- chat_partitioner=llama2_chat_partitioner,
- chat_emulator=LlamaChatEmulator(),
-)
diff --git a/aidial_adapter_bedrock/llm/model/llama/v3.py b/aidial_adapter_bedrock/llm/model/llama/v3.py
index 6ec44575..182203ea 100644
--- a/aidial_adapter_bedrock/llm/model/llama/v3.py
+++ b/aidial_adapter_bedrock/llm/model/llama/v3.py
@@ -1,72 +1,40 @@
-"""
-Turning a chat into a prompt for the Llama3 model.
+import json
-See as a reference:
-https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py
-"""
-
-from typing import List, Literal, Optional, Tuple, assert_never
-
-from aidial_adapter_bedrock.llm.chat_emulator import ChatEmulator
-from aidial_adapter_bedrock.llm.message import (
- AIRegularMessage,
- BaseMessage,
- HumanRegularMessage,
- SystemMessage,
+from aidial_adapter_bedrock.dial_api.request import ModelParameters
+from aidial_adapter_bedrock.llm.converse.adapter import (
+ ConverseAdapter,
+ ConverseMessages,
)
-from aidial_adapter_bedrock.llm.model.llama.conf import LlamaConf
-
-
-def get_role(message: BaseMessage) -> Literal["system", "user", "assistant"]:
- match message:
- case SystemMessage():
- return "system"
- case HumanRegularMessage():
- return "user"
- case AIRegularMessage():
- return "assistant"
- case _:
- assert_never(message)
-
-
-def encode_header(message: BaseMessage) -> str:
- ret = ""
- ret += "<|start_header_id|>"
- ret += get_role(message)
- ret += "<|end_header_id|>"
- ret += "\n\n"
- return ret
-
-
-def encode_message(message: BaseMessage) -> str:
- ret = encode_header(message)
- ret += message.text_content.strip()
- ret += "<|eot_id|>"
- return ret
-
-
-def encode_dialog_prompt(messages: List[BaseMessage]) -> str:
- ret = ""
- ret += "<|begin_of_text|>"
- for message in messages:
- ret += encode_message(message)
- ret += encode_header(AIRegularMessage(content=""))
- return ret
+from aidial_adapter_bedrock.llm.converse.types import (
+ ConverseDeployment,
+ ConverseRequestWrapper,
+)
+from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string
-class LlamaChatEmulator(ChatEmulator):
- def display(self, messages: List[BaseMessage]) -> Tuple[str, List[str]]:
- return encode_dialog_prompt(messages), []
+class ConverseAdapterWithStreamingEmulation(ConverseAdapter):
+ """
+    Llama 3 models support tools only in non-streaming mode,
+    so the request is sent without streaming and streaming is then emulated.
+ """
- def get_ai_cue(self) -> Optional[str]:
- return None
+ def is_stream(self, params: ModelParameters) -> bool:
+ if self.get_tool_config(params):
+ return False
+ return params.stream
-def llama3_chat_partitioner(messages: List[BaseMessage]) -> List[int]:
- return [1] * len(messages)
+def input_tokenizer_factory(
+ deployment: ConverseDeployment, params: ConverseRequestWrapper
+):
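+    # Approximate token counting: the tool config, the system prompt and each
+    # message are serialized to JSON and counted with the default string tokenizer.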
+ tool_tokens = default_tokenize_string(json.dumps(params.toolConfig))
+ system_tokens = default_tokenize_string(json.dumps(params.system))
+ async def tokenizer(msg_items: ConverseMessages) -> int:
+ tokens = sum(
+ default_tokenize_string(json.dumps(msg_item[0]))
+ for msg_item in msg_items
+ )
+ return tokens + tool_tokens + system_tokens
-llama3_config = LlamaConf(
- chat_partitioner=llama3_chat_partitioner,
- chat_emulator=LlamaChatEmulator(),
-)
+ return tokenizer
diff --git a/aidial_adapter_bedrock/llm/model/meta.py b/aidial_adapter_bedrock/llm/model/meta.py
deleted file mode 100644
index a0b6e868..00000000
--- a/aidial_adapter_bedrock/llm/model/meta.py
+++ /dev/null
@@ -1,126 +0,0 @@
-from typing import Any, AsyncIterator, Dict, List, Optional
-
-from aidial_sdk.chat_completion import Message
-from typing_extensions import override
-
-from aidial_adapter_bedrock.bedrock import (
- Bedrock,
- ResponseWithInvocationMetricsMixin,
-)
-from aidial_adapter_bedrock.dial_api.request import ModelParameters
-from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
-from aidial_adapter_bedrock.llm.chat_model import PseudoChatModel
-from aidial_adapter_bedrock.llm.consumer import Consumer
-from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_META
-from aidial_adapter_bedrock.llm.model.llama.conf import LlamaConf
-from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string
-from aidial_adapter_bedrock.llm.tools.default_emulator import (
- default_tools_emulator,
-)
-
-
-class MetaResponse(ResponseWithInvocationMetricsMixin):
- generation: str
- prompt_token_count: Optional[int]
- generation_token_count: Optional[int]
- stop_reason: Optional[str]
-
- def content(self) -> str:
- return self.generation
-
- def usage(self) -> TokenUsage:
- return TokenUsage(
- prompt_tokens=self.prompt_token_count or 0,
- completion_tokens=self.generation_token_count or 0,
- )
-
-
-def convert_params(params: ModelParameters) -> Dict[str, Any]:
- ret = {}
-
- if params.temperature is not None:
- ret["temperature"] = params.temperature
-
- if params.top_p is not None:
- ret["top_p"] = params.top_p
-
- if params.max_tokens is not None:
- ret["max_gen_len"] = params.max_tokens
- else:
- # Choosing reasonable default
- ret["max_gen_len"] = DEFAULT_MAX_TOKENS_META
-
- return ret
-
-
-def create_request(prompt: str, params: Dict[str, Any]) -> Dict[str, Any]:
- return {"prompt": prompt, **params}
-
-
-async def chunks_to_stream(
- chunks: AsyncIterator[dict], usage: TokenUsage
-) -> AsyncIterator[str]:
- async for chunk in chunks:
- resp = MetaResponse.parse_obj(chunk)
- usage.accumulate(resp.usage_by_metrics())
- yield resp.content()
-
-
-async def response_to_stream(
- response: dict, usage: TokenUsage
-) -> AsyncIterator[str]:
- resp = MetaResponse.parse_obj(response)
- usage.accumulate(resp.usage())
- yield resp.content()
-
-
-class MetaAdapter(PseudoChatModel):
- model: str
- client: Bedrock
-
- @classmethod
- def create(cls, client: Bedrock, model: str, conf: LlamaConf):
- return cls(
- client=client,
- model=model,
- tokenize_string=default_tokenize_string,
- tools_emulator=default_tools_emulator,
- chat_emulator=conf.chat_emulator,
- partitioner=conf.chat_partitioner,
- )
-
- @override
- def preprocess_messages(self, messages: List[Message]) -> List[Message]:
- messages = super().preprocess_messages(messages)
-
- # Llama behaves strangely on empty prompt:
- # it generate empty string, but claims to used up all available completion tokens.
- # So replace it with a single space.
- for msg in messages:
- msg.content = msg.content or " "
-
- return messages
-
- async def predict(
- self, consumer: Consumer, params: ModelParameters, prompt: str
- ):
- args = create_request(prompt, convert_params(params))
-
- usage = TokenUsage()
-
- if params.stream:
- chunks = self.client.ainvoke_streaming(self.model, args)
- stream = chunks_to_stream(chunks, usage)
- else:
- response, _headers = await self.client.ainvoke_non_streaming(
- self.model, args
- )
- stream = response_to_stream(response, usage)
-
- stream = self.post_process_stream(stream, params, self.chat_emulator)
-
- async for content in stream:
- consumer.append_content(content)
- consumer.close_content()
-
- consumer.add_usage(usage)
diff --git a/poetry.lock b/poetry.lock
index dd3bdb8e..38ad6c76 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -337,41 +337,41 @@ uvloop = ["uvloop (>=0.15.2)"]
[[package]]
name = "boto3"
-version = "1.28.57"
+version = "1.35.41"
description = "The AWS SDK for Python"
optional = false
-python-versions = ">= 3.7"
+python-versions = ">=3.8"
files = [
- {file = "boto3-1.28.57-py3-none-any.whl", hash = "sha256:5ddf24cf52c7fb6aaa332eaa08ae8c2afc8f2d1e8860680728533dd573904e32"},
- {file = "boto3-1.28.57.tar.gz", hash = "sha256:e2d2824ba6459b330d097e94039a9c4f96ae3f4bcdc731d620589ad79dcd16d3"},
+ {file = "boto3-1.35.41-py3-none-any.whl", hash = "sha256:2bf7e7f376aee52155fc4ae4487f29333a6bcdf3a05c3bc4fede10b972d951a6"},
+ {file = "boto3-1.35.41.tar.gz", hash = "sha256:e74bc6d69c04ca611b7f58afe08e2ded6cb6504a4a80557b656abeefee395f88"},
]
[package.dependencies]
-botocore = ">=1.31.57,<1.32.0"
+botocore = ">=1.35.41,<1.36.0"
jmespath = ">=0.7.1,<2.0.0"
-s3transfer = ">=0.7.0,<0.8.0"
+s3transfer = ">=0.10.0,<0.11.0"
[package.extras]
crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
-version = "1.31.57"
+version = "1.35.41"
description = "Low-level, data-driven core of boto 3."
optional = false
-python-versions = ">= 3.7"
+python-versions = ">=3.8"
files = [
- {file = "botocore-1.31.57-py3-none-any.whl", hash = "sha256:af006248276ff8e19e3ec7214478f6257035eb40aed865e405486500471ae71b"},
- {file = "botocore-1.31.57.tar.gz", hash = "sha256:301436174635bec739b225b840fc365ca00e5c1a63e5b2a19ee679d204e01b78"},
+ {file = "botocore-1.35.41-py3-none-any.whl", hash = "sha256:915c4d81e3a0be3b793c1e2efdf19af1d0a9cd4a2d8de08ee18216c14d67764b"},
+ {file = "botocore-1.35.41.tar.gz", hash = "sha256:8a09a32136df8768190a6c92f0240cd59c30deb99c89026563efadbbed41fa00"},
]
[package.dependencies]
jmespath = ">=0.7.1,<2.0.0"
python-dateutil = ">=2.1,<3.0.0"
-urllib3 = ">=1.25.4,<1.27"
+urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}
[package.extras]
-crt = ["awscrt (==0.16.26)"]
+crt = ["awscrt (==0.22.0)"]
[[package]]
name = "certifi"
@@ -2190,20 +2190,20 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "s3transfer"
-version = "0.7.0"
+version = "0.10.3"
description = "An Amazon S3 Transfer Manager"
optional = false
-python-versions = ">= 3.7"
+python-versions = ">=3.8"
files = [
- {file = "s3transfer-0.7.0-py3-none-any.whl", hash = "sha256:10d6923c6359175f264811ef4bf6161a3156ce8e350e705396a7557d6293c33a"},
- {file = "s3transfer-0.7.0.tar.gz", hash = "sha256:fd3889a66f5fe17299fe75b82eae6cf722554edca744ca5d5fe308b104883d2e"},
+ {file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"},
+ {file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"},
]
[package.dependencies]
-botocore = ">=1.12.36,<2.0a.0"
+botocore = ">=1.33.2,<2.0a.0"
[package.extras]
-crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"]
+crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]
[[package]]
name = "setuptools"
@@ -2652,4 +2652,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it
[metadata]
lock-version = "2.0"
python-versions = "^3.11,<4.0"
-content-hash = "831f19b9b372a53821d8076e6900345d47804fab89d5606f8f919b03d7ed7647"
+content-hash = "a99ddf2ca2b4ce8adf619fb1b52c1ec3d62d476cad7723c362830d8fe6a22741"
diff --git a/pyproject.toml b/pyproject.toml
index 91051e27..e4fdf61e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,8 +15,8 @@ repository = "https://github.com/epam/ai-dial-adapter-bedrock/"
[tool.poetry.dependencies]
python = "^3.11,<4.0"
-boto3 = "1.28.57"
-botocore = "1.31.57"
+boto3 = "1.35.41"
+botocore = "1.35.41"
aidial-sdk = {version = "0.14.0", extras = ["telemetry"]}
anthropic = {version = "0.28.1", extras = ["bedrock"]}
fastapi = "0.115.2"
diff --git a/tests/integration_tests/constants.py b/tests/integration_tests/constants.py
index 8504d015..6c15a9bc 100644
--- a/tests/integration_tests/constants.py
+++ b/tests/integration_tests/constants.py
@@ -1,6 +1,14 @@
+from pathlib import Path
+
from aidial_adapter_bedrock.utils.resource import Resource
BLUE_PNG_PICTURE = Resource.from_base64(
type="image/png",
data_base64="iVBORw0KGgoAAAANSUhEUgAAAAMAAAADCAIAAADZSiLoAAAAF0lEQVR4nGNkYPjPwMDAwMDAxAADCBYAG10BBdmz9y8AAAAASUVORK5CYII=",
)
+CURRENT_DIR = Path(__file__).parent
+SAMPLE_DOG_IMAGE_PATH = CURRENT_DIR / "images" / "dog-sample-image.png"
+SAMPLE_DOG_RESOURCE = Resource(
+ type="image/png",
+ data=SAMPLE_DOG_IMAGE_PATH.read_bytes(),
+)
diff --git a/tests/integration_tests/test_chat_completion.py b/tests/integration_tests/test_chat_completion.py
index adf8cb1d..a1030a60 100644
--- a/tests/integration_tests/test_chat_completion.py
+++ b/tests/integration_tests/test_chat_completion.py
@@ -16,7 +16,7 @@
UpstreamConfig,
)
from aidial_adapter_bedrock.deployments import ChatCompletionDeployment
-from tests.integration_tests.constants import BLUE_PNG_PICTURE
+from tests.integration_tests.constants import SAMPLE_DOG_RESOURCE
from tests.utils.openai import (
GET_WEATHER_FUNCTION,
ChatCompletionResult,
@@ -70,14 +70,16 @@ class TestCase:
functions: List[Function] | None
tools: List[ChatCompletionToolParam] | None
+ temperature: float = 0.0
def get_id(self):
max_tokens_str = f"maxt={self.max_tokens}" if self.max_tokens else ""
stop_sequence_str = f"stop={self.stop}" if self.stop else ""
n_str = f"n={self.n}" if self.n else ""
+ temperature_str = f"temp={self.temperature}" if self.temperature else ""
return sanitize_test_name(
f"{self.deployment.value} {self.streaming} {max_tokens_str} "
- f"{stop_sequence_str} {n_str} {self.name}"
+ f"{stop_sequence_str} {n_str} {temperature_str} {self.name}"
)
@@ -99,13 +101,17 @@ def get_id(self):
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_US: _WEST,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_V2: _WEST,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_V2_US: _WEST,
- ChatCompletionDeployment.META_LLAMA2_13B_CHAT_V1: _WEST,
- ChatCompletionDeployment.META_LLAMA2_70B_CHAT_V1: _WEST,
ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1: _WEST,
ChatCompletionDeployment.META_LLAMA3_70B_INSTRUCT_V1: _WEST,
- ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1: _WEST,
- ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1: _WEST,
ChatCompletionDeployment.META_LLAMA3_1_8B_INSTRUCT_V1: _WEST,
+ ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1: _WEST,
+ ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1: _WEST,
+    # Llama 3.2 1B responses are too unstable for integration tests:
+    # sometimes it cannot even calculate 2+2.
+ # ChatCompletionDeployment.META_LLAMA3_2_1B_INSTRUCT_V1: _WEST,
+ ChatCompletionDeployment.META_LLAMA3_2_3B_INSTRUCT_V1: _WEST,
+ ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1: _WEST,
+ ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1: _WEST,
ChatCompletionDeployment.COHERE_COMMAND_TEXT_V14: _WEST,
ChatCompletionDeployment.COHERE_COMMAND_LIGHT_TEXT_V14: _WEST,
}
@@ -127,6 +133,9 @@ def supports_tools(deployment: ChatCompletionDeployment) -> bool:
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_HAIKU_EU,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_OPUS,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_OPUS_US,
+ ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1,
+ ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1,
+ ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1,
]
@@ -134,6 +143,8 @@ def supports_parallel_tool_calls(deployment: ChatCompletionDeployment) -> bool:
return deployment not in [
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_V2,
ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_V2_US,
+ ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1,
+ ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1,
] and supports_tools(deployment)
@@ -141,6 +152,13 @@ def is_llama3(deployment: ChatCompletionDeployment) -> bool:
return deployment in [
ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1,
ChatCompletionDeployment.META_LLAMA3_70B_INSTRUCT_V1,
+ ChatCompletionDeployment.META_LLAMA3_1_8B_INSTRUCT_V1,
+ ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1,
+ ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1,
+ ChatCompletionDeployment.META_LLAMA3_2_1B_INSTRUCT_V1,
+ ChatCompletionDeployment.META_LLAMA3_2_3B_INSTRUCT_V1,
+ ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1,
+ ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1,
]
@@ -184,7 +202,10 @@ def is_ai21(deployment: ChatCompletionDeployment) -> bool:
def is_vision_model(deployment: ChatCompletionDeployment) -> bool:
- return is_claude3(deployment)
+ return is_claude3(deployment) or deployment in [
+ ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1,
+ ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1,
+ ]
def are_tools_emulated(deployment: ChatCompletionDeployment) -> bool:
@@ -218,6 +239,7 @@ def test_case(
stop: List[str] | None = None,
functions: List[Function] | None = None,
tools: List[ChatCompletionToolParam] | None = None,
+ temperature: float = 0.0,
) -> None:
test_cases.append(
TestCase(
@@ -232,6 +254,7 @@ def test_case(
n,
functions,
tools,
+ temperature,
)
)
@@ -304,6 +327,14 @@ def dial_recall_expected(r: ChatCompletionResult):
expected_empty_message_error = streaming_error(
cohere_invalid_request_error
)
+ elif is_llama3(deployment):
+ expected_empty_message_error = streaming_error(
+ ExpectedException(
+ type=BadRequestError,
+ message="Add text to the text field, and try again.",
+ status_code=400,
+ )
+ )
test_case(
name="empty user message",
@@ -325,6 +356,14 @@ def dial_recall_expected(r: ChatCompletionResult):
expected_whitespace_message = streaming_error(
cohere_invalid_request_error
)
+ elif is_llama3(deployment):
+ expected_whitespace_message = streaming_error(
+ ExpectedException(
+ type=BadRequestError,
+ message="Add text to the text field, and try again.",
+ status_code=400,
+ )
+ )
test_case(
name="single space user message",
@@ -337,16 +376,16 @@ def dial_recall_expected(r: ChatCompletionResult):
content = "describe the image"
for idx, user_message in enumerate(
[
- user_with_attachment_data(content, BLUE_PNG_PICTURE),
- user_with_attachment_url(content, BLUE_PNG_PICTURE),
- user_with_image_url(content, BLUE_PNG_PICTURE),
+ user_with_attachment_data(content, SAMPLE_DOG_RESOURCE),
+ user_with_attachment_url(content, SAMPLE_DOG_RESOURCE),
+ user_with_image_url(content, SAMPLE_DOG_RESOURCE),
]
):
test_case(
name=f"describe image {idx}",
max_tokens=100,
messages=[sys("be a helpful assistant"), user_message], # type: ignore
- expected=lambda s: "blue" in s.content.lower(),
+ expected=lambda s: "dog" in s.content.lower(),
)
test_case(
@@ -370,10 +409,17 @@ def dial_recall_expected(r: ChatCompletionResult):
)
if is_llama3(deployment):
+
test_case(
- name="out of turn",
+ name="out_of_turn",
messages=[ai("hello"), user("what's 7+5?")],
- expected=lambda s: "12" in s.content.lower(),
+ expected=streaming_error(
+ ExpectedException(
+ type=BadRequestError,
+ message="A conversation must start with a user message",
+ status_code=400,
+ )
+ ),
)
test_case(
@@ -405,11 +451,13 @@ def dial_recall_expected(r: ChatCompletionResult):
query = f"What's the temperature in {' and in '.join(city_names)} in celsius?"
init_messages = [
- sys("act as a helpful assistant"),
user("2+3=?"),
ai("5"),
user(query),
]
+    # Llama 3 handles system messages poorly when tools are provided
+ if not is_llama3(deployment):
+ init_messages.insert(0, sys("act as a helpful assistant"))
def create_fun_args(city: str):
return {
@@ -433,6 +481,7 @@ def check_fun_args(city: str):
expected=lambda s, n=city_names[0]: is_valid_function_call(
s.function_call, fun_name, check_fun_args(n)
),
+ temperature=1 if is_llama3(deployment) else 0.0,
)
function_req = ai_function(
@@ -454,6 +503,7 @@ def check_fun_args(city: str):
expected=lambda s, t=city_temps[0]: s.content_contains_all(
[t]
),
+ temperature=1 if is_llama3(deployment) else 0.0,
)
else:
test_case(
@@ -467,6 +517,7 @@ def check_fun_args(city: str):
expected=lambda s, n=city_names[1]: is_valid_function_call(
s.function_call, fun_name, check_fun_args(n)
),
+ temperature=1 if is_llama3(deployment) else 0.0,
)
# Tools
@@ -501,6 +552,7 @@ def _check(id: str) -> bool:
)
for idx in range(len(n))
),
+ temperature=1 if is_llama3(deployment) else 0.0,
)
tool_reqs = ai_tools(
@@ -523,6 +575,7 @@ def _check(id: str) -> bool:
messages=[*init_messages, tool_reqs, *tool_resps],
tools=[tool],
expected=lambda s, t=city_temps: s.content_contains_all(t),
+ temperature=1 if is_llama3(deployment) else 0.0,
)
return test_cases
@@ -562,6 +615,7 @@ async def run_chat_completion() -> ChatCompletionResult:
test.n,
test.functions,
test.tools,
+ test.temperature,
)
if isinstance(test.expected, ExpectedException):
diff --git a/tests/integration_tests/test_stable_diffusion.py b/tests/integration_tests/test_stable_diffusion.py
index dd412d3b..05e5b4ac 100644
--- a/tests/integration_tests/test_stable_diffusion.py
+++ b/tests/integration_tests/test_stable_diffusion.py
@@ -1,5 +1,4 @@
import base64
-from pathlib import Path
from typing import Dict
from unittest.mock import patch
@@ -12,7 +11,10 @@
)
from aidial_adapter_bedrock.deployments import ChatCompletionDeployment
from aidial_adapter_bedrock.utils.resource import Resource
-from tests.integration_tests.constants import BLUE_PNG_PICTURE
+from tests.integration_tests.constants import (
+ BLUE_PNG_PICTURE,
+ SAMPLE_DOG_RESOURCE,
+)
from tests.utils.mock_storage import MockFileStorage
from tests.utils.openai import (
user,
@@ -34,13 +36,6 @@
]
VISION_MODEL = ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_US
-CURRENT_DIR = Path(__file__).parent
-SAMPLE_DOG_IMAGE_PATH = CURRENT_DIR / "images" / "dog-sample-image.png"
-SAMPLE_DOG_RESOURCE = Resource(
- type="image/png",
- data=SAMPLE_DOG_IMAGE_PATH.read_bytes(),
-)
-
def get_upstream_headers(region: str) -> Dict[str, str]:
return {
diff --git a/tests/unit_tests/chat_emulation/test_llama2_chat.py b/tests/unit_tests/chat_emulation/test_llama2_chat.py
deleted file mode 100644
index a0010a87..00000000
--- a/tests/unit_tests/chat_emulation/test_llama2_chat.py
+++ /dev/null
@@ -1,168 +0,0 @@
-from typing import List, Optional
-
-import pytest
-
-from aidial_adapter_bedrock.llm.chat_model import keep_last_and_system_messages
-from aidial_adapter_bedrock.llm.errors import ValidationError
-from aidial_adapter_bedrock.llm.message import BaseMessage
-from aidial_adapter_bedrock.llm.model.llama.v2 import llama2_config
-from aidial_adapter_bedrock.llm.truncate_prompt import (
- DiscardedMessages,
- TruncatePromptError,
- compute_discarded_messages,
-)
-from tests.utils.messages import ai, sys, user
-
-llama2_chat_emulator = llama2_config.chat_emulator
-llama2_chat_partitioner = llama2_config.chat_partitioner
-
-
-async def truncate_prompt_by_words(
- messages: List[BaseMessage],
- user_limit: int,
- model_limit: Optional[int] = None,
-) -> DiscardedMessages | TruncatePromptError:
- async def _tokenize_by_words(messages: List[BaseMessage]) -> int:
- return sum(len(msg.text_content.split()) for msg in messages)
-
- return await compute_discarded_messages(
- messages=messages,
- tokenizer=_tokenize_by_words,
- keep_message=keep_last_and_system_messages,
- partitioner=llama2_chat_partitioner,
- model_limit=model_limit,
- user_limit=user_limit,
- )
-
-
-def test_construction_single_message():
- messages: List[BaseMessage] = [
- user(" human message1 "),
- ]
-
- text, stop_sequences = llama2_chat_emulator.display(messages)
-
- assert stop_sequences == []
- assert text == "[INST] human message1 [/INST]"
-
-
-def test_construction_many_without_system():
- messages = [
- user(" human message1 "),
- ai(" ai message1 "),
- user(" human message2 "),
- ]
-
- text, stop_sequences = llama2_chat_emulator.display(messages)
-
- assert stop_sequences == []
- assert text == "".join(
- [
- "[INST] human message1 [/INST]",
- " ai message1 ",
- "[INST] human message2 [/INST]",
- ]
- )
-
-
-def test_construction_many_with_system():
- messages = [
- sys(" system message1 "),
- user(" human message1 "),
- ai(" ai message1 "),
- user(" human message2 "),
- ]
-
- text, stop_sequences = llama2_chat_emulator.display(messages)
-
- assert stop_sequences == []
- assert text == "".join(
- [
-            "[INST] <<SYS>>\n system message1 \n<</SYS>>\n\n human message1 [/INST]",
- " ai message1 ",
- "[INST] human message2 [/INST]",
- ]
- )
-
-
-def test_invalid_alternation():
- messages = [
- ai(" ai message1 "),
- user(" human message1 "),
- user(" human message2 "),
- ]
-
- with pytest.raises(ValidationError) as exc_info:
- llama2_chat_emulator.display(messages)
-
- assert exc_info.value.message == (
- "The model only supports initial optional system message and"
- " follow-up alternating human/assistant messages"
- )
-
-
-def test_invalid_last_message():
- messages = [
- user(" human message1 "),
- ai(" ai message1 "),
- user(" human message2 "),
- ai(" ai message2 "),
- ]
-
- with pytest.raises(ValidationError) as exc_info:
- llama2_chat_emulator.display(messages)
-
- assert exc_info.value.message == "The last message must be from user"
-
-
-turns_sys = [
- sys("system"),
- user("hello"),
- ai("hi"),
- user("ping"),
- ai("pong"),
- user("improvise"),
-]
-
-turns_no_sys = turns_sys[1:]
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
- "messages, user_limit, expected",
- [
- (
- turns_sys,
- 1,
- "The requested maximum prompt tokens is 1. "
- "However, the system messages and the last user message resulted in 2 tokens. "
- "Please reduce the length of the messages or increase the maximum prompt tokens.",
- ),
- (turns_sys, 2, [1, 2, 3, 4]),
- (turns_sys, 3, [1, 2, 3, 4]),
- (turns_sys, 4, [1, 2]),
- (turns_sys, 5, [1, 2]),
- (turns_sys, 6, []),
- (turns_no_sys, 1, [0, 1, 2, 3]),
- (turns_no_sys, 2, [0, 1, 2, 3]),
- (turns_no_sys, 3, [0, 1]),
- (turns_no_sys, 4, [0, 1]),
- (turns_no_sys, 5, []),
- ],
-)
-async def test_multi_turn_dialogue(
- messages: List[BaseMessage],
- user_limit: int,
- expected: DiscardedMessages | str,
-):
- discarded_messages = await truncate_prompt_by_words(
- messages=messages, user_limit=user_limit
- )
-
- if isinstance(expected, str):
- assert (
- isinstance(discarded_messages, TruncatePromptError)
- and discarded_messages.print() == expected
- )
- else:
- assert discarded_messages == expected
diff --git a/tests/unit_tests/converse/test_converse_adapter.py b/tests/unit_tests/converse/test_converse_adapter.py
new file mode 100644
index 00000000..15ce6084
--- /dev/null
+++ b/tests/unit_tests/converse/test_converse_adapter.py
@@ -0,0 +1,427 @@
+from dataclasses import dataclass
+from typing import List
+
+import pytest
+from aidial_sdk.chat_completion.request import (
+ Function,
+ FunctionCall,
+ ImageURL,
+ Message,
+ MessageContentImagePart,
+ MessageContentTextPart,
+ Role,
+ ToolCall,
+)
+
+from aidial_adapter_bedrock.aws_client_config import AWSClientConfig
+from aidial_adapter_bedrock.bedrock import Bedrock
+from aidial_adapter_bedrock.dial_api.request import ModelParameters
+from aidial_adapter_bedrock.llm.converse.adapter import ConverseAdapter
+from aidial_adapter_bedrock.llm.converse.types import (
+ ConverseImagePart,
+ ConverseImagePartConfig,
+ ConverseImageSource,
+ ConverseMessage,
+ ConverseRequestWrapper,
+ ConverseRole,
+ ConverseTextPart,
+ ConverseToolResultPart,
+ ConverseToolUseConfig,
+ ConverseToolUsePart,
+ InferenceConfig,
+)
+from aidial_adapter_bedrock.llm.errors import ValidationError
+from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig
+from aidial_adapter_bedrock.utils.list_projection import ListProjection
+from tests.integration_tests.constants import BLUE_PNG_PICTURE
+
+
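+# Stub tokenizer factory: reports a fixed count of 100 tokens for any message list.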
+async def _input_tokenizer_factory(_deployment, _params):
+ async def _test_tokenizer(_messages) -> int:
+ return 100
+
+ return _test_tokenizer
+
+
+@dataclass
+class ExpectedException:
+ type: type[Exception]
+ message: str
+
+
+@dataclass
+class TestCase:
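+    # Not a pytest test class despite the name; disable collection.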
+ __test__ = False
+ name: str
+ messages: List[Message]
+ params: ModelParameters
+ expected_output: ConverseRequestWrapper | None = None
+ expected_error: ExpectedException | None = None
+
+
+default_inference_config = InferenceConfig(stopSequences=[])
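+# Cases below cover plain text, system-message handling, tool calls, image parts and message merging.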
+TEST_CASES = [
+ TestCase(
+ name="plain_message",
+ messages=[Message(role=Role.USER, content="Hello, world!")],
+ params=ModelParameters(tool_config=None),
+ expected_output=ConverseRequestWrapper(
+ inferenceConfig=default_inference_config,
+ messages=ListProjection(
+ list=[
+ (
+ ConverseMessage(
+ role=ConverseRole.USER,
+ content=[ConverseTextPart(text="Hello, world!")],
+ ),
+ {0},
+ )
+ ]
+ ),
+ ),
+ ),
+ TestCase(
+ name="system_message",
+ messages=[
+ Message(role=Role.SYSTEM, content="You are a helpful assistant."),
+ Message(role=Role.USER, content="Hello!"),
+ ],
+ params=ModelParameters(tool_config=None),
+ expected_output=ConverseRequestWrapper(
+ inferenceConfig=default_inference_config,
+ system=[ConverseTextPart(text="You are a helpful assistant.")],
+ messages=ListProjection(
+ list=[
+ (
+ ConverseMessage(
+ role=ConverseRole.USER,
+ content=[ConverseTextPart(text="Hello!")],
+ ),
+ {1},
+ )
+ ]
+ ),
+ ),
+ ),
+ TestCase(
+ name="system_message_after_user",
+ messages=[
+ Message(role=Role.SYSTEM, content="You are a helpful assistant."),
+ Message(role=Role.USER, content="Hello!"),
+ Message(role=Role.SYSTEM, content="You are a helpful assistant."),
+ ],
+ params=ModelParameters(tool_config=None),
+ expected_error=ExpectedException(
+ type=ValidationError,
+ message="A system message can only follow another system message",
+ ),
+ ),
+ TestCase(
+ name="multiple_system_messages",
+ messages=[
+ Message(role=Role.SYSTEM, content="You are a helpful assistant."),
+ Message(role=Role.SYSTEM, content="You are also very friendly."),
+ Message(role=Role.USER, content="Hello!"),
+ ],
+ params=ModelParameters(tool_config=None),
+ expected_output=ConverseRequestWrapper(
+ inferenceConfig=default_inference_config,
+ system=[
+ ConverseTextPart(
+ text="You are a helpful assistant.\n\nYou are also very friendly."
+ ),
+ ],
+ messages=ListProjection(
+ list=[
+ (
+ ConverseMessage(
+ role=ConverseRole.USER,
+ content=[ConverseTextPart(text="Hello!")],
+ ),
+ {2},
+ )
+ ]
+ ),
+ ),
+ ),
+ TestCase(
+ name="system_message_multiple_parts",
+ messages=[
+ Message(
+ role=Role.SYSTEM,
+ content=[
+ MessageContentTextPart(
+ type="text", text="You are a helpful assistant."
+ ),
+ MessageContentTextPart(
+ type="text", text="You are also very friendly."
+ ),
+ ],
+ ),
+ Message(role=Role.USER, content="Hello!"),
+ ],
+ params=ModelParameters(tool_config=None),
+ expected_output=ConverseRequestWrapper(
+ inferenceConfig=default_inference_config,
+ system=[
+ ConverseTextPart(
+ text="You are a helpful assistant.\n\nYou are also very friendly."
+ ),
+ ],
+ messages=ListProjection(
+ list=[
+ (
+ ConverseMessage(
+ role=ConverseRole.USER,
+ content=[ConverseTextPart(text="Hello!")],
+ ),
+ {1},
+ )
+ ]
+ ),
+ ),
+ ),
+ TestCase(
+ name="system_message_with_forbidden_image",
+ messages=[
+ Message(
+ role=Role.SYSTEM,
+ content=[
+ MessageContentTextPart(
+ type="text", text="You are a helpful assistant."
+ ),
+ MessageContentImagePart(
+ type="image_url",
+ image_url=ImageURL(url=BLUE_PNG_PICTURE.to_data_url()),
+ ),
+ ],
+ ),
+ Message(role=Role.USER, content="Hello!"),
+ ],
+ params=ModelParameters(tool_config=None),
+ expected_error=ExpectedException(
+ type=ValidationError,
+ message="System messages cannot contain images",
+ ),
+ ),
+ TestCase(
+ name="tools_convert",
+ messages=[
+ Message(role=Role.USER, content="What's the weather?"),
+ Message(
+ role=Role.ASSISTANT,
+ content=None,
+ tool_calls=[
+ ToolCall(
+ index=0,
+ id="call_123",
+ type="function",
+ function=FunctionCall(
+ name="get_weather",
+ arguments='{"location": "London"}',
+ ),
+ )
+ ],
+ ),
+ Message(
+ role=Role.TOOL,
+ content='{"temperature": "20C"}',
+ tool_call_id="call_123",
+ ),
+ ],
+ params=ModelParameters(
+ tool_config=ToolsConfig(
+ functions=[
+ Function(
+ name="get_weather",
+ description="Get the weather",
+ parameters={"type": "object", "properties": {}},
+ )
+ ],
+ required=True,
+ tool_ids=None,
+ )
+ ),
+ expected_output=ConverseRequestWrapper(
+ inferenceConfig=default_inference_config,
+ toolConfig={
+ "tools": [
+ {
+ "toolSpec": {
+ "name": "get_weather",
+ "description": "Get the weather",
+ "inputSchema": {
+ "json": {"properties": {}, "type": "object"}
+ },
+ }
+ }
+ ],
+ "toolChoice": {"any": {}},
+ },
+ messages=ListProjection(
+ list=[
+ (
+ ConverseMessage(
+ role=ConverseRole.USER,
+ content=[
+ ConverseTextPart(text="What's the weather?")
+ ],
+ ),
+ {0},
+ ),
+ (
+ ConverseMessage(
+ role=ConverseRole.ASSISTANT,
+ content=[
+ ConverseToolUsePart(
+ toolUse=ConverseToolUseConfig(
+ toolUseId="call_123",
+ name="get_weather",
+ input={"location": "London"},
+ )
+ )
+ ],
+ ),
+ {1},
+ ),
+ (
+ ConverseMessage(
+ role=ConverseRole.USER,
+ content=[
+ ConverseToolResultPart(
+ toolResult={
+ "toolUseId": "call_123",
+ "content": [
+ {"json": {"temperature": "20C"}}
+ ],
+ "status": "success",
+ }
+ )
+ ],
+ ),
+ {2},
+ ),
+ ]
+ ),
+ ),
+ ),
+ TestCase(
+ name="content_parts",
+ messages=[
+ Message(
+ role=Role.USER,
+ content=[
+ MessageContentTextPart(type="text", text="Hello!"),
+ MessageContentImagePart(
+ type="image_url",
+ image_url=ImageURL(url=BLUE_PNG_PICTURE.to_data_url()),
+ ),
+ ],
+ )
+ ],
+ params=ModelParameters(tool_config=None),
+ expected_output=ConverseRequestWrapper(
+ inferenceConfig=default_inference_config,
+ messages=ListProjection(
+ list=[
+ (
+ ConverseMessage(
+ role=ConverseRole.USER,
+ content=[
+ ConverseTextPart(text="Hello!"),
+ ConverseImagePart(
+ image=ConverseImagePartConfig(
+ format="png",
+ source=ConverseImageSource(
+ bytes=BLUE_PNG_PICTURE.data
+ ),
+ )
+ ),
+ ],
+ ),
+ {0},
+ )
+ ]
+ ),
+ ),
+ ),
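+    # Consecutive messages with the same role are expected to be merged into one Converse message, preserving original indices.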
+ TestCase(
+ name="shrink_messages",
+ messages=[
+ Message(role=Role.USER, content="Say hello."),
+ Message(role=Role.USER, content="And have a good day."),
+ Message(
+ role=Role.ASSISTANT,
+ content="Hello",
+ ),
+ Message(
+ role=Role.ASSISTANT,
+ content=[
+ MessageContentTextPart(type="text", text="Have a nice"),
+ MessageContentTextPart(type="text", text="day!"),
+ ],
+ ),
+ ],
+ params=ModelParameters(temperature=10),
+ expected_output=ConverseRequestWrapper(
+ inferenceConfig=InferenceConfig(temperature=10, stopSequences=[]),
+ messages=ListProjection(
+ list=[
+ (
+ ConverseMessage(
+ role=ConverseRole.USER,
+ content=[
+ ConverseTextPart(text="Say hello."),
+ ConverseTextPart(text="And have a good day."),
+ ],
+ ),
+ {0, 1},
+ ),
+ (
+ ConverseMessage(
+ role=ConverseRole.ASSISTANT,
+ content=[
+ ConverseTextPart(text="Hello"),
+ ConverseTextPart(text="Have a nice"),
+ ConverseTextPart(text="day!"),
+ ],
+ ),
+ {2, 3},
+ ),
+ ]
+ ),
+ ),
+ ),
+]
+
+
+@pytest.mark.parametrize(
+ "test_case", TEST_CASES, ids=lambda test_case: test_case.name
+)
+@pytest.mark.asyncio
+async def test_converse_adapter(
+ test_case: TestCase,
+):
+ adapter = ConverseAdapter(
+ deployment="test",
+ bedrock=await Bedrock.acreate(AWSClientConfig(region="us-east-1")),
+ tokenize_text=lambda x: len(x),
+ input_tokenizer_factory=_input_tokenizer_factory, # type: ignore
+ support_tools=True,
+ storage=None,
+ )
+ construct_coro = adapter.construct_converse_params(
+ messages=test_case.messages,
+ params=test_case.params,
+ )
+
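+    # Either construction raises the expected validation error with its exact message, or the built request matches the expected wrapper.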
+ if test_case.expected_error is not None:
+ with pytest.raises(test_case.expected_error.type) as exc_info:
+ converse_request = await construct_coro
+ assert hasattr(exc_info.value, "message")
+ error_message = getattr(exc_info.value, "message")
+ assert isinstance(error_message, str)
+ assert error_message == test_case.expected_error.message
+ else:
+ converse_request = await construct_coro
+ assert converse_request == test_case.expected_output
diff --git a/tests/unit_tests/converse/test_to_converse_message.py b/tests/unit_tests/converse/test_to_converse_message.py
new file mode 100644
index 00000000..e86052a4
--- /dev/null
+++ b/tests/unit_tests/converse/test_to_converse_message.py
@@ -0,0 +1,184 @@
+import pytest
+from aidial_sdk.chat_completion import FunctionCall
+from aidial_sdk.chat_completion import Message as DialMessage
+from aidial_sdk.chat_completion import Role as DialRole
+from aidial_sdk.chat_completion import ToolCall
+
+from aidial_adapter_bedrock.llm.converse.input import to_converse_message
+
+
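+# The tests below check the DIAL -> Converse message mapping; storage=None since no attachments are involved.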
+@pytest.mark.asyncio
+async def test_to_converse_message_text():
+ dial_message = DialMessage(role=DialRole.USER, content="Hello, world!")
+ converse_message = await to_converse_message(dial_message, storage=None)
+
+ assert converse_message == {
+ "role": "user",
+ "content": [{"text": "Hello, world!"}],
+ }
+
+
+@pytest.mark.asyncio
+async def test_to_converse_message_assistant():
+ dial_message = DialMessage(role=DialRole.ASSISTANT, content="Hello")
+ converse_message = await to_converse_message(dial_message, storage=None)
+
+ assert converse_message == {
+ "role": "assistant",
+ "content": [{"text": "Hello"}],
+ }
+
+
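+# Legacy function_call messages carry no tool-call id, so the function name is reused as the toolUseId.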
+@pytest.mark.asyncio
+async def test_to_converse_message_function_call_no_content():
+ dial_message = DialMessage(
+ role=DialRole.ASSISTANT,
+ function_call=FunctionCall(
+ name="get_weather", arguments='{"city": "Paris"}'
+ ),
+ )
+ converse_message = await to_converse_message(dial_message, storage=None)
+ assert converse_message == {
+ "role": "assistant",
+ "content": [
+ {
+ "toolUse": {
+ "toolUseId": "get_weather",
+ "name": "get_weather",
+ "input": {"city": "Paris"},
+ }
+ },
+ ],
+ }
+
+
+@pytest.mark.asyncio
+async def test_to_converse_message_function_call_with_content():
+ dial_message = DialMessage(
+ role=DialRole.ASSISTANT,
+ content="Calling a function",
+ function_call=FunctionCall(
+ name="get_weather", arguments='{"city": "Paris"}'
+ ),
+ )
+ converse_message = await to_converse_message(dial_message, storage=None)
+ assert converse_message == {
+ "role": "assistant",
+ "content": [
+ {"text": "Calling a function"},
+ {
+ "toolUse": {
+ "toolUseId": "get_weather",
+ "name": "get_weather",
+ "input": {"city": "Paris"},
+ }
+ },
+ ],
+ }
+
+
+@pytest.mark.asyncio
+async def test_to_converse_message_tool_call_no_content():
+ dial_message = DialMessage(
+ role=DialRole.ASSISTANT,
+ tool_calls=[
+ ToolCall(
+ index=None,
+ id="123",
+ type="function",
+ function=FunctionCall(
+ name="get_weather", arguments='{"city": "Paris"}'
+ ),
+ )
+ ],
+ )
+ converse_message = await to_converse_message(dial_message, storage=None)
+ assert converse_message == {
+ "role": "assistant",
+ "content": [
+ {
+ "toolUse": {
+ "toolUseId": "123",
+ "name": "get_weather",
+ "input": {"city": "Paris"},
+ }
+ },
+ ],
+ }
+
+
+@pytest.mark.asyncio
+async def test_to_converse_message_tool_call_with_content():
+ dial_message = DialMessage(
+ role=DialRole.ASSISTANT,
+ content="Calling a function",
+ tool_calls=[
+ ToolCall(
+ index=None,
+ id="123",
+ type="function",
+ function=FunctionCall(
+ name="get_weather", arguments='{"city": "Paris"}'
+ ),
+ )
+ ],
+ )
+ converse_message = await to_converse_message(dial_message, storage=None)
+ assert converse_message == {
+ "role": "assistant",
+ "content": [
+ {"text": "Calling a function"},
+ {
+ "toolUse": {
+ "toolUseId": "123",
+ "name": "get_weather",
+ "input": {"city": "Paris"},
+ }
+ },
+ ],
+ }
+
+
+@pytest.mark.asyncio
+async def test_to_converse_message_multiple_tool_calls():
+ dial_message = DialMessage(
+ role=DialRole.ASSISTANT,
+ tool_calls=[
+ ToolCall(
+ index=None,
+ id="123",
+ type="function",
+ function=FunctionCall(
+ name="get_weather", arguments='{"city": "Paris"}'
+ ),
+ ),
+ ToolCall(
+ index=None,
+ id="456",
+ type="function",
+ function=FunctionCall(
+ name="get_weather", arguments='{"city": "London"}'
+ ),
+ ),
+ ],
+ )
+ converse_message = await to_converse_message(dial_message, storage=None)
+ assert converse_message == {
+ "role": "assistant",
+ "content": [
+ {
+ "toolUse": {
+ "toolUseId": "123",
+ "name": "get_weather",
+ "input": {"city": "Paris"},
+ }
+ },
+ {
+ "toolUse": {
+ "toolUseId": "456",
+ "name": "get_weather",
+ "input": {"city": "London"},
+ }
+ },
+ ],
+ }
diff --git a/tests/unit_tests/test_endpoints.py b/tests/unit_tests/test_endpoints.py
index 79afa80b..3748ae25 100644
--- a/tests/unit_tests/test_endpoints.py
+++ b/tests/unit_tests/test_endpoints.py
@@ -36,13 +36,15 @@
),
(ChatCompletionDeployment.STABILITY_STABLE_IMAGE_ULTRA_V1, False, True),
(ChatCompletionDeployment.STABILITY_STABLE_IMAGE_CORE_V1, False, True),
- (ChatCompletionDeployment.META_LLAMA2_13B_CHAT_V1, True, True),
- (ChatCompletionDeployment.META_LLAMA2_70B_CHAT_V1, True, True),
(ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1, True, True),
(ChatCompletionDeployment.META_LLAMA3_70B_INSTRUCT_V1, True, True),
- (ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1, True, True),
- (ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1, True, True),
(ChatCompletionDeployment.META_LLAMA3_1_8B_INSTRUCT_V1, True, True),
+ (ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1, True, True),
+ (ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1, True, True),
+ (ChatCompletionDeployment.META_LLAMA3_2_1B_INSTRUCT_V1, True, True),
+ (ChatCompletionDeployment.META_LLAMA3_2_3B_INSTRUCT_V1, True, True),
+ (ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1, True, True),
+ (ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1, True, True),
(ChatCompletionDeployment.COHERE_COMMAND_TEXT_V14, True, True),
(ChatCompletionDeployment.COHERE_COMMAND_LIGHT_TEXT_V14, True, True),
]
diff --git a/tests/utils/openai.py b/tests/utils/openai.py
index 04f8475c..81033d00 100644
--- a/tests/utils/openai.py
+++ b/tests/utils/openai.py
@@ -197,6 +197,7 @@ async def chat_completion(
n: Optional[int],
functions: List[Function] | None,
tools: List[ChatCompletionToolParam] | None,
+ temperature: float = 0.0,
) -> ChatCompletionResult:
async def get_response() -> ChatCompletion:
response = await client.chat.completions.create(
@@ -205,7 +206,7 @@ async def get_response() -> ChatCompletion:
stream=stream,
stop=stop,
max_tokens=max_tokens,
- temperature=0.0,
+ temperature=temperature,
n=n,
function_call="auto" if functions is not None else NOT_GIVEN,
functions=functions or NOT_GIVEN,