diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md
index 417c661fe..33cdca69c 100644
--- a/integrations/amazon_bedrock/CHANGELOG.md
+++ b/integrations/amazon_bedrock/CHANGELOG.md
@@ -2,6 +2,10 @@
 
 ## [unreleased]
 
+### 🚀 Features
+
+- New generator for the Bedrock Converse API (#977)
+
 ### 🐛 Bug Fixes
 
 - *(Bedrock)* Allow tools kwargs for AWS Bedrock Claude model (#976)
diff --git a/integrations/amazon_bedrock/examples/converse_generator_example.py b/integrations/amazon_bedrock/examples/converse_generator_example.py
new file mode 100644
index 000000000..82969b6c2
--- /dev/null
+++ b/integrations/amazon_bedrock/examples/converse_generator_example.py
@@ -0,0 +1,47 @@
+from haystack import Pipeline
+
+from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockConverseGenerator
+from haystack_integrations.components.generators.amazon_bedrock.converse.utils import ConverseMessage, ToolConfig
+
+
+def get_current_weather(location: str, unit: str = "celsius") -> str:
+    """Get the current weather in a given location"""
+    return f"The weather in {location} is 22 degrees {unit}."
+
+
+def get_current_time(timezone: str) -> str:
+    """Get the current time in a given timezone"""
+    return f"The current time in {timezone} is 14:30."
+
+
+generator = AmazonBedrockConverseGenerator(
+    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
+    streaming_callback=print,
+)
+
+# Create ToolConfig from functions
+tool_config = ToolConfig.from_functions([get_current_weather, get_current_time])
+
+tool_config_dict = tool_config.to_dict()
+
+
+pipeline = Pipeline()
+pipeline.add_component("generator", generator)
+
+result = pipeline.run(
+    data={
+        "generator": {
+            "inference_config": {
+                "temperature": 0.1,
+                "maxTokens": 256,
+                "topP": 0.1,
+                "stopSequences": ["\n"],
+            },
+            "messages": [
+                ConverseMessage.from_user(["What's the weather like in Paris and what time is it in New York?"]),
+            ],
+            "tool_config": tool_config_dict,
+        },
+    },
+)
+print(result)
diff --git a/integrations/amazon_bedrock/pydoc/config.yml b/integrations/amazon_bedrock/pydoc/config.yml
index 6cb05d6f3..d84f0bac9 100644
--- a/integrations/amazon_bedrock/pydoc/config.yml
+++ b/integrations/amazon_bedrock/pydoc/config.yml
@@ -7,6 +7,7 @@ loaders:
       "haystack_integrations.common.amazon_bedrock.errors",
       "haystack_integrations.components.generators.amazon_bedrock.handlers",
       "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator",
+      "haystack_integrations.components.generators.amazon_bedrock.converse.converse_generator",
       "haystack_integrations.components.embedders.amazon_bedrock.text_embedder",
       "haystack_integrations.components.embedders.amazon_bedrock.document_embedder",
     ]
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py
index 2d33beb42..88139497c 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py
@@ -2,6 +2,16 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 from .chat.chat_generator import AmazonBedrockChatGenerator
+from .converse.capabilities import MODEL_CAPABILITIES
+from .converse.converse_generator import AmazonBedrockConverseGenerator
+from .converse.utils import ConverseMessage, ToolConfig
 from .generator import 
AmazonBedrockGenerator -__all__ = ["AmazonBedrockGenerator", "AmazonBedrockChatGenerator"] +__all__ = [ + "AmazonBedrockGenerator", + "AmazonBedrockChatGenerator", + "AmazonBedrockConverseGenerator", + "ConverseMessage", + "ToolConfig", + "MODEL_CAPABILITIES", +] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py new file mode 100644 index 000000000..98dcd53b3 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py @@ -0,0 +1,109 @@ +from enum import Enum, auto + + +class ModelCapability(Enum): + CONVERSE = auto() + CONVERSE_STREAM = auto() + SYSTEM_PROMPTS = auto() + DOCUMENT_CHAT = auto() + VISION = auto() + TOOL_USE = auto() + STREAMING_TOOL_USE = auto() + GUARDRAILS = auto() + + +# https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html + +MODEL_CAPABILITIES = { + "ai21.jamba-instruct-.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + }, + "ai21.j2-.*": { + ModelCapability.CONVERSE, + ModelCapability.GUARDRAILS, + }, + "amazon.titan-text-.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.GUARDRAILS, + }, + "anthropic.claude-v2.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.GUARDRAILS, + }, + "anthropic.claude-3.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.VISION, + ModelCapability.TOOL_USE, + ModelCapability.STREAMING_TOOL_USE, + ModelCapability.GUARDRAILS, + }, + "cohere.command-text.*": { + ModelCapability.CONVERSE, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.GUARDRAILS, + }, + "cohere.command-light.*": { + ModelCapability.CONVERSE, + ModelCapability.GUARDRAILS, + }, + "cohere.command-r.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.TOOL_USE, + }, + "meta.llama2.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.GUARDRAILS, + }, + "meta.llama3.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.GUARDRAILS, + }, + "meta.llama3-1.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.TOOL_USE, + ModelCapability.GUARDRAILS, + }, + "mistral.mistral-.*-instruct": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + 
ModelCapability.DOCUMENT_CHAT,
+        ModelCapability.GUARDRAILS,
+    },
+    "mistral.mistral-large.*": {
+        ModelCapability.CONVERSE,
+        ModelCapability.CONVERSE_STREAM,
+        ModelCapability.SYSTEM_PROMPTS,
+        ModelCapability.DOCUMENT_CHAT,
+        ModelCapability.TOOL_USE,
+        ModelCapability.GUARDRAILS,
+    },
+    "mistral.mistral-small.*": {
+        ModelCapability.CONVERSE,
+        ModelCapability.CONVERSE_STREAM,
+        ModelCapability.SYSTEM_PROMPTS,
+        ModelCapability.TOOL_USE,
+        ModelCapability.GUARDRAILS,
+    },
+}
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py
new file mode 100644
index 000000000..876c314fa
--- /dev/null
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py
@@ -0,0 +1,347 @@
+import logging
+import re
+from typing import Any, Callable, Dict, List, Optional, Set
+
+from botocore.exceptions import ClientError
+from haystack import component, default_from_dict, default_to_dict
+from haystack.utils.auth import Secret, deserialize_secrets_inplace
+from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
+
+from haystack_integrations.common.amazon_bedrock.errors import (
+    AmazonBedrockConfigurationError,
+    AmazonBedrockInferenceError,
+)
+from haystack_integrations.common.amazon_bedrock.utils import get_aws_session
+
+from .capabilities import (
+    MODEL_CAPABILITIES,
+    ModelCapability,
+)
+from .utils import (
+    ContentBlock,
+    ConverseMessage,
+    ImageBlock,
+    StreamEvent,
+    ToolConfig,
+    get_stream_message,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@component
+class AmazonBedrockConverseGenerator:
+    """
+    Completes chats using LLMs hosted on Amazon Bedrock through the Converse API.
+    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
+
+    For example, to use the Anthropic Claude 3 Sonnet model, initialize this component with the
+    'anthropic.claude-3-sonnet-20240229-v1:0' model name.
+
+    ### Usage example
+
+    ```python
+    from haystack import Pipeline
+
+    from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockConverseGenerator
+    from haystack_integrations.components.generators.amazon_bedrock.converse.utils import ConverseMessage, ToolConfig
+
+
+    def get_current_weather(location: str, unit: str = "celsius") -> str:
+        return f"The weather in {location} is 22 degrees {unit}."
+
+
+    def get_current_time(timezone: str) -> str:
+        return f"The current time in {timezone} is 14:30."
+
+
+    generator = AmazonBedrockConverseGenerator(
+        model="anthropic.claude-3-5-sonnet-20240620-v1:0",
+        # streaming_callback=print,
+    )
+
+    # Create ToolConfig from functions
+    tool_config = ToolConfig.from_functions([get_current_weather, get_current_time])
+
+    # Convert ToolConfig to dict for use in the run method
+    tool_config_dict = tool_config.to_dict()
+
+    print("Tool Config:")
+    print(tool_config_dict)
+
+    pipeline = Pipeline()
+    pipeline.add_component("generator", generator)
+
+    result = pipeline.run(
+        data={
+            "generator": {
+                "inference_config": {
+                    "temperature": 0.1,
+                    "maxTokens": 256,
+                    "topP": 0.1,
+                    "stopSequences": ["\n"],
+                },
+                "messages": [
+                    ConverseMessage.from_user(["What's the weather like in Paris and what time is it in New York?"]),
+                ],
+                "tool_config": tool_config_dict,
+            },
+        },
+    )
+    print(result)
+    ```
+
+    AmazonBedrockConverseGenerator uses AWS for authentication. You can use the AWS CLI to authenticate through IAM.
+    For more information on setting up an IAM identity-based policy, see the [Amazon Bedrock documentation]
+    (https://docs.aws.amazon.com/bedrock/latest/userguide/security_iam_id-based-policy-examples.html).
+
+    If the AWS environment is configured correctly, the AWS credentials are not required as they're loaded
+    automatically from the environment or the AWS configuration file.
+    If the AWS environment is not configured, set `aws_access_key_id`, `aws_secret_access_key`,
+    and `aws_region_name` as environment variables or pass them as
+    [Secret](https://docs.haystack.deepset.ai/v2.0/docs/secret-management) arguments. Make sure the region you set
+    supports Amazon Bedrock.
+    """
+
+    # Supported request fields follow the toolConfig argument of the Converse API:
+    # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestSyntax
+
+    def __init__(
+        self,
+        model: str,
+        aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False),  # noqa: B008
+        aws_secret_access_key: Optional[Secret] = Secret.from_env_var(  # noqa: B008
+            ["AWS_SECRET_ACCESS_KEY"], strict=False
+        ),
+        aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False),  # noqa: B008
+        aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False),  # noqa: B008
+        aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False),  # noqa: B008
+        inference_config: Optional[Dict[str, Any]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        streaming_callback: Optional[Callable[[StreamEvent], None]] = None,
+        system_prompt: Optional[List[Dict[str, Any]]] = None,
+    ):
+        """
+        Initializes the `AmazonBedrockConverseGenerator` with the provided parameters. The parameters are passed to
+        the Amazon Bedrock client.
+
+        Note that the AWS credentials are not required if the AWS environment is configured correctly. They are
+        loaded automatically from the environment or the AWS configuration file and do not need to be provided
+        explicitly via the constructor. If the AWS environment is not configured, users need to provide the AWS
+        credentials via the constructor. Aside from `model`, the required parameters are `aws_access_key_id`,
+        `aws_secret_access_key`, and `aws_region_name`.
+
+        :param model: The model to use for text generation. The model must be available in Amazon Bedrock and must
+            be specified in the format outlined in the
+            [Amazon Bedrock documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html).
+        :param aws_access_key_id: AWS access key ID.
+        :param aws_secret_access_key: AWS secret access key.
+        :param aws_session_token: AWS session token.
+        :param aws_region_name: AWS region name. Make sure the region you set supports Amazon Bedrock.
+        :param aws_profile_name: AWS profile name.
+        :param inference_config: A dictionary containing the inference configuration, for example `temperature`,
+            `maxTokens`, `topP`, and `stopSequences`. The default value is None.
+        :param tool_config: A `ToolConfig` object describing the tools the model is allowed to call.
+            The default value is None.
+        :param streaming_callback: A callback function called when a new chunk is received from the stream.
+            By default, the model is not set up for streaming. To enable streaming, set this parameter to a
+            callback function that handles the stream events. The callback function receives a `StreamEvent`
+            object and switches the streaming mode on.
+        :param system_prompt: A list of dictionaries containing the system prompt. The default value is None.
+        """
+        if not model:
+            msg = "'model' cannot be None or empty string"
+            raise ValueError(msg)
+        self.model = model
+        self.aws_access_key_id = aws_access_key_id
+        self.aws_secret_access_key = aws_secret_access_key
+        self.aws_session_token = aws_session_token
+        self.aws_region_name = aws_region_name
+        self.aws_profile_name = aws_profile_name
+
+        self.inference_config = inference_config
+        self.tool_config = tool_config
+        self.streaming_callback = streaming_callback
+        self.system_prompt = system_prompt
+        self.model_capabilities = self._get_model_capabilities(model)
+
+        def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
+            return secret.resolve_value() if secret else None
+
+        # create the AWS session and client
+        try:
+            session = get_aws_session(
+                aws_access_key_id=resolve_secret(aws_access_key_id),
+                aws_secret_access_key=resolve_secret(aws_secret_access_key),
+                aws_session_token=resolve_secret(aws_session_token),
+                aws_region_name=resolve_secret(aws_region_name),
+                aws_profile_name=resolve_secret(aws_profile_name),
+            )
+            self.client = session.client("bedrock-runtime")
+        except Exception as exception:
+            msg = (
+                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
+                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
+            )
+            raise AmazonBedrockConfigurationError(msg) from exception
+
+    def _get_model_capabilities(self, model: str) -> Set[ModelCapability]:
+        for pattern, capabilities in MODEL_CAPABILITIES.items():
+            if re.match(pattern, model):
+                return capabilities
+        unsupported_model_error = ValueError(f"Unsupported model: {model}")
+        raise unsupported_model_error
+
+    @component.output_types(
+        message=ConverseMessage,
+        usage=Dict[str, Any],
+        metrics=Dict[str, Any],
+        guardrail_trace=Dict[str, Any],
+        stop_reason=str,
+    )
+    def run(
+        self,
+        messages: List[ConverseMessage],
+        streaming_callback: Optional[Callable[[StreamEvent], None]] = None,
+        inference_config: Optional[Dict[str, Any]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        system_prompt: Optional[List[Dict[str, Any]]] = None,
+    ):
+        streaming_callback = streaming_callback or self.streaming_callback
+        system_prompt = system_prompt or self.system_prompt
+        tool_config = tool_config or self.tool_config
+
+        if not (
+            isinstance(messages, list)
+            and len(messages) > 0
+            and all(isinstance(message, ConverseMessage) for message in messages)
+        ):
+            msg = f"The model {self.model} requires a list of ConverseMessage objects as input."
+            raise ValueError(msg)
+
+        # Check and filter messages based on model capabilities
+        if ModelCapability.SYSTEM_PROMPTS not in self.model_capabilities and system_prompt:
+            logger.warning(
+                f"The model {self.model} does not support system prompts. The provided system_prompt will be ignored."
+            )
+            system_prompt = None
+
+        if ModelCapability.VISION not in self.model_capabilities:
+            messages = [
+                ConverseMessage(
+                    role=msg.role,
+                    content=ContentBlock(
+                        content=[item for item in msg.content.content if not isinstance(item, ImageBlock)]
+                    ),
+                )
+                for msg in messages
+            ]
+
+        if ModelCapability.DOCUMENT_CHAT not in self.model_capabilities:
+            logger.warning(
+                f"The model {self.model} does not support document chat. This feature will not be available."
+            )
+
+        if ModelCapability.TOOL_USE not in self.model_capabilities and tool_config:
+            logger.warning(f"The model {self.model} does not support tools. The provided tool_config will be ignored.")
+            tool_config = None
+
+        if ModelCapability.STREAMING_TOOL_USE not in self.model_capabilities and streaming_callback and tool_config:
+            logger.warning(
+                f"The model {self.model} does not support streaming tool use. "
+                "Tool calls may not stream as expected."
+            )
+
+        request_kwargs = {
+            "modelId": self.model,
+            "inferenceConfig": inference_config or self.inference_config,
+            "messages": [message.to_dict() for message in messages],
+        }
+
+        if tool_config:
+            request_kwargs["toolConfig"] = tool_config.to_dict()
+        if system_prompt:
+            # the Converse API expects `system` to be a list of system content blocks
+            request_kwargs["system"] = system_prompt
+
+        try:
+            if streaming_callback and ModelCapability.CONVERSE_STREAM in self.model_capabilities:
+                converse_response = self.client.converse_stream(**request_kwargs)
+                response_stream = converse_response.get("stream")
+                message, metadata = get_stream_message(
+                    stream=response_stream,
+                    streaming_callback=streaming_callback,
+                )
+            else:
+                converse_response = self.client.converse(**request_kwargs)
+                output = converse_response.get("output")
+                if output is None:
+                    response_output_missing_error = "Response does not contain 'output'"
+                    raise KeyError(response_output_missing_error)
+                output_message = output.get("message")
+                if output_message is None:
+                    response_output_missing_message_error = "Response 'output' does not contain 'message'"
+                    raise KeyError(response_output_missing_message_error)
+                message = ConverseMessage.from_dict(output_message)
+                metadata = converse_response
+
+            return {
+                "message": message,
+                "usage": metadata.get("usage"),
+                "metrics": metadata.get("metrics"),
+                "guardrail_trace": (
+                    metadata.get("trace") if ModelCapability.GUARDRAILS in self.model_capabilities else None
+                ),
+                "stop_reason": metadata.get("stopReason"),
+            }
+
+        except ClientError as exception:
+            msg = f"Could not run inference on Amazon Bedrock model {self.model} due to: {exception}"
+            raise AmazonBedrockInferenceError(msg) from exception
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+
+        :returns:
+            Dictionary with serialized data.
+ """ + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None, + aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None, + aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None, + aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None, + aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, + model=self.model, + streaming_callback=callback_name, + system_prompt=self.system_prompt, + inference_config=self.inference_config, + tool_config=self.tool_config.to_dict() if self.tool_config else None, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockConverseGenerator": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + tool_config = data.get("init_parameters", {}).get("tool_config") + if tool_config: + data["init_parameters"]["tool_config"] = ToolConfig.from_dict(tool_config) + deserialize_secrets_inplace( + data["init_parameters"], + [ + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region_name", + "aws_profile_name", + ], + ) + return default_from_dict(cls, data) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py new file mode 100644 index 000000000..9ebf03806 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -0,0 +1,422 @@ +import inspect +import json +import logging +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union + +from botocore.eventstream import EventStream + + +@dataclass +class ToolSpec: + name: str + description: Optional[str] = None + input_schema: Dict[str, Dict] = field(default_factory=dict) + + +@dataclass +class Tool: + tool_spec: ToolSpec + + +@dataclass +class ToolChoice: + auto: Dict = field(default_factory=dict) + any: Dict = field(default_factory=dict) + tool: Optional[Dict[str, str]] = None + + +@dataclass +class ToolConfig: + tools: List[Tool] + tool_choice: Optional[ToolChoice] = None + + def __post_init__(self): + msg = "Only one of 'auto', 'any', or 'tool' can be set in tool_choice" + if self.tool_choice and sum(bool(v) for v in vars(self.tool_choice).values()) != 1: + raise ValueError(msg) + + if self.tool_choice and self.tool_choice.tool: + if "name" not in self.tool_choice.tool: + msg = "'name' is required when 'tool' is specified in tool_choice" + raise ValueError(msg) + + @staticmethod + def from_functions(functions: List[Callable]) -> "ToolConfig": + tools = [] + for func in functions: + tool_spec = ToolSpec( + name=func.__name__, + description=func.__doc__, + input_schema={ + "json": { + "type": "object", + "properties": {param: {"type": "string"} for param in 
inspect.signature(func).parameters},
+                        "required": list(inspect.signature(func).parameters.keys()),
+                    }
+                },
+            )
+            tools.append(Tool(tool_spec=tool_spec))
+
+        return ToolConfig(tools=tools)
+
+    @classmethod
+    def from_dict(cls, config: Dict) -> "ToolConfig":
+        tools = [
+            Tool(
+                ToolSpec(
+                    input_schema=tool["toolSpec"]["inputSchema"],
+                    name=tool["toolSpec"]["name"],
+                    description=tool["toolSpec"]["description"],
+                )
+            )
+            for tool in config.get("tools", [])
+        ]
+
+        tool_choice = None
+        if "tool_choice" in config:
+            tc = config["tool_choice"]
+            if "auto" in tc:
+                tool_choice = ToolChoice(auto=tc["auto"])
+            elif "any" in tc:
+                tool_choice = ToolChoice(any=tc["any"])
+            elif "tool" in tc:
+                tool_choice = ToolChoice(tool={"name": tc["tool"]["name"]})
+
+        return cls(tools=tools, tool_choice=tool_choice)
+
+    def to_dict(self) -> Dict[str, Any]:
+        result = {
+            "tools": [
+                {
+                    "toolSpec": {
+                        "name": tool.tool_spec.name,
+                        "description": tool.tool_spec.description,
+                        "inputSchema": tool.tool_spec.input_schema,
+                    }
+                }
+                for tool in self.tools
+            ]
+        }
+        if self.tool_choice:
+            tool_choice: Dict[str, Dict[str, Any]] = {}
+            if self.tool_choice.auto:
+                tool_choice["auto"] = self.tool_choice.auto
+            elif self.tool_choice.any:
+                tool_choice["any"] = self.tool_choice.any
+            elif self.tool_choice.tool:
+                tool_choice["tool"] = self.tool_choice.tool
+            # store as a plain dict so that from_dict can round-trip it
+            result["tool_choice"] = tool_choice
+        return result
+
+
+@dataclass
+class DocumentSource:
+    bytes: bytes
+
+
+@dataclass
+class DocumentBlock:
+    SUPPORTED_FORMATS = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
+    format: SUPPORTED_FORMATS
+    name: str
+    source: bytes
+
+
+@dataclass
+class GuardrailConverseContentBlock:
+    text: str
+    qualifiers: List[str] = field(default_factory=list)
+
+
+@dataclass
+class ImageSource:
+    bytes: bytes
+
+
+@dataclass
+class ImageBlock:
+    format: str
+    source: ImageSource
+
+
+@dataclass
+class ToolResultContentBlock:
+    json: Optional[Dict] = None
+    text: Optional[str] = None
+    image: Optional[ImageBlock] = None
+    document: Optional[DocumentBlock] = None
+
+
+@dataclass
+class ToolResultBlock:
+    tool_use_id: str
+    content: List[ToolResultContentBlock]
+    status: Optional[str] = None
+
+
+@dataclass
+class ToolUseBlock:
+    tool_use_id: str
+    name: str
+    input: Dict[str, Any]
+
+
+class ConverseRole(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+
+
+@dataclass
+class ContentBlock:
+    content: List[Union[DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, ToolUseBlock]]
+
+    def __post_init__(self):
+        err_msg = "Content must be a list"
+        if not isinstance(self.content, list):
+            raise ValueError(err_msg)
+
+        for item in self.content:
+            if not isinstance(
+                item, (DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, ToolUseBlock)
+            ):
+                msg = (
+                    f"Invalid content type: {type(item)}. 
Each item must be one of DocumentBlock, " + "GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, or ToolUseBlock" + ) + raise ValueError(msg) + + @staticmethod + def from_assistant(content: Sequence[Union[str, ToolUseBlock]]) -> "ContentBlock": + return ContentBlock(content=list(content)) + + def to_dict(self): + res = [] + for item in self.content: + if isinstance(item, str): + res.append({"text": item}) + elif isinstance(item, DocumentBlock): + res.append({"document": asdict(item)}) + elif isinstance(item, GuardrailConverseContentBlock): + res.append({"guardContent": asdict(item)}) + elif isinstance(item, ImageBlock): + res.append({"image": asdict(item)}) + elif isinstance(item, ToolResultBlock): + res.append({"toolResult": asdict(item)}) + elif isinstance(item, ToolUseBlock): + res.append({"toolUse": asdict(item)}) + else: + msg = f"Unsupported content type: {type(item)}" + raise ValueError(msg) + return res + + +@dataclass +class ConverseMessage: + role: ConverseRole + content: ContentBlock + + @staticmethod + def from_user( + content: List[ + Union[ + DocumentBlock, + GuardrailConverseContentBlock, + ImageBlock, + str, + ToolUseBlock, + ToolResultBlock, + ], + ], + ) -> "ConverseMessage": + return ConverseMessage( + ConverseRole.USER, + ContentBlock( + content=content, + ), + ) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> "ConverseMessage": + role = ConverseRole.ASSISTANT if data["role"] == "assistant" else ConverseRole.USER + content_blocks = [] + + for item in data["content"]: + if "text" in item: + content_blocks.append(item["text"]) + elif "image" in item: + content_blocks.append(ImageBlock(**item["image"])) + elif "document" in item: + content_blocks.append(DocumentBlock(**item["document"])) + elif "toolUse" in item: + content_blocks.append( + ToolUseBlock( + tool_use_id=item["toolUse"]["toolUseId"], + name=item["toolUse"]["name"], + input=item["toolUse"]["input"], + ) + ) + elif "toolResult" in item: + content_blocks.append(ToolResultBlock(**item["toolResult"])) + elif "guardContent" in item: + content_blocks.append(GuardrailConverseContentBlock(**item["guardContent"])) + else: + unknown_type = f"Unknown content type in message: {item}" + raise ValueError(unknown_type) + + return ConverseMessage(role, ContentBlock(content=content_blocks)) + + def to_dict(self): + return { + "role": self.role.value, + "content": self.content.to_dict(), + } + + +@dataclass +class StreamEvent: + type: str + data: Dict[str, Any] + + +def _parse_event(event: Dict[str, Any]) -> StreamEvent: + for key in ["contentBlockStart", "contentBlockDelta", "contentBlockStop", "messageStop", "messageStart"]: + if key in event: + return StreamEvent(type=key, data=event[key]) + return StreamEvent(type="metadata", data=event.get("metadata", {})) + + +def _handle_content_block_start(event: StreamEvent, current_index: int) -> Tuple[int, Union[str, ToolUseBlock]]: + new_index = event.data.get("contentBlockIndex", current_index + 1) + start_of_tool_use = event.data.get("start") + if start_of_tool_use: + return new_index, ToolUseBlock( + tool_use_id=start_of_tool_use["toolUse"]["toolUseId"], + name=start_of_tool_use["toolUse"]["name"], + input={}, + ) + return new_index, "" + + +def _handle_content_block_delta( + event: StreamEvent, current_block: Union[str, ToolUseBlock], current_tool_use_input_str: str +) -> Tuple[Union[str, ToolUseBlock], str]: + delta = event.data.get("delta", {}) + if "text" in delta: + if isinstance(current_block, str): + return current_block + delta["text"], 
current_tool_use_input_str
+        else:
+            return delta["text"], current_tool_use_input_str
+    if "toolUse" in delta:
+        if isinstance(current_block, ToolUseBlock):
+            return current_block, current_tool_use_input_str + delta["toolUse"].get("input", "")
+        else:
+            return ToolUseBlock(
+                tool_use_id=delta["toolUse"]["toolUseId"],
+                name=delta["toolUse"]["name"],
+                input={},
+            ), delta["toolUse"].get("input", "")
+    return current_block, current_tool_use_input_str
+
+
+def get_stream_message(
+    stream: EventStream,
+    streaming_callback: Callable[[StreamEvent], None],
+) -> Tuple[ConverseMessage, Dict[str, Any]]:
+    """
+    Processes a stream of events and returns a ConverseMessage and the associated metadata.
+
+    The stream is expected to contain the following events:
+
+    - contentBlockStart: Indicates the start of a content block.
+    - contentBlockDelta: Indicates a change to the content block.
+    - contentBlockStop: Indicates the end of a content block.
+    - messageStop: Indicates the end of a message.
+    - metadata: Indicates metadata about the message.
+
+    The returned ConverseMessage contains the accumulated content of the message,
+    and the metadata contains the stop reason and any other metadata from the stream.
+
+    The function also calls the streaming_callback function with a StreamEvent for each event in the stream.
+
+    :param stream: The stream of events to process.
+    :param streaming_callback: The callback function to call with each StreamEvent.
+    :return: A tuple containing the ConverseMessage and the associated metadata.
+    """
+    current_block: Union[str, ToolUseBlock] = ""
+    current_tool_use_input_str: str = ""
+    latest_metadata: Dict[str, Any] = {}
+    current_index: int = 0
+    streamed_contents: List[Union[str, ToolUseBlock]] = []
+
+    try:
+        for raw_event in stream:
+            event = _parse_event(raw_event)
+
+            if event.type == "contentBlockStart":
+                if current_block:
+                    if isinstance(current_block, str) and streamed_contents and isinstance(streamed_contents[-1], str):
+                        streamed_contents[-1] += current_block
+                    else:
+                        streamed_contents.append(current_block)
+                current_index, current_block = _handle_content_block_start(event, current_index)
+
+            elif event.type == "contentBlockDelta":
+                new_block, new_input_str = _handle_content_block_delta(event, current_block, current_tool_use_input_str)
+                if isinstance(new_block, ToolUseBlock) and new_block != current_block:
+                    if current_block:
+                        if (
+                            isinstance(current_block, str)
+                            and streamed_contents
+                            and isinstance(streamed_contents[-1], str)
+                        ):
+                            streamed_contents[-1] += current_block
+                        else:
+                            streamed_contents.append(current_block)
+                    current_index += 1
+                current_block, current_tool_use_input_str = new_block, new_input_str
+
+            elif event.type == "contentBlockStop":
+                if isinstance(current_block, ToolUseBlock):
+                    current_block.input = json.loads(current_tool_use_input_str)
+                    current_tool_use_input_str = ""
+                    streamed_contents.append(current_block)
+                elif isinstance(current_block, str):
+                    if streamed_contents and isinstance(streamed_contents[-1], str):
+                        streamed_contents[-1] += current_block
+                    else:
+                        streamed_contents.append(current_block)
+                current_block = ""
+                current_index += 1
+
+            elif event.type == "messageStop":
+                latest_metadata["stopReason"] = event.data.get("stopReason")
+
+            latest_metadata.update(event.data if event.type == "metadata" else
{}) + streaming_callback(event) + + except Exception as e: + logging.error(f"Error processing stream: {e!s}") + raise + + # Add any remaining content + if current_block: + if isinstance(current_block, str) and streamed_contents and isinstance(streamed_contents[-1], str): + streamed_contents[-1] += current_block + else: + streamed_contents.append(current_block) + + return ( + ConverseMessage( + role=ConverseRole.ASSISTANT, + content=ContentBlock.from_assistant(streamed_contents), + ), + latest_metadata, + ) diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py new file mode 100644 index 000000000..af8e621d8 --- /dev/null +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -0,0 +1,378 @@ +import json +from unittest.mock import Mock, patch + +import pytest +from botocore.exceptions import ClientError + +from haystack_integrations.common.amazon_bedrock.errors import ( + AmazonBedrockInferenceError, +) +from haystack_integrations.components.generators.amazon_bedrock import ( + MODEL_CAPABILITIES, + AmazonBedrockConverseGenerator, + ConverseMessage, + ToolConfig, +) +from haystack_integrations.components.generators.amazon_bedrock.converse.utils import ( + ConverseRole, + ImageBlock, + ImageSource, + ToolUseBlock, +) + + +def get_current_weather(location: str, unit: str = "celsius") -> str: + """Get the current weather in a given location""" + # This is a mock function, replace with actual API call + return f"The weather in {location} is 22 degrees {unit}." + + +def get_current_time(timezone: str) -> str: + """Get the current time in a given timezone""" + # This is a mock function, replace with actual time lookup + return f"The current time in {timezone} is 14:30." + + +def test_to_dict(mock_boto3_session): + """ + Test that the to_dict method returns the correct dictionary without aws credentials + """ + tool_config = ToolConfig.from_functions([get_current_weather, get_current_time]) + + generator = AmazonBedrockConverseGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", + inference_config={ + "temperature": 0.1, + "maxTokens": 256, + "topP": 0.1, + "stopSequences": ["\\n"], + }, + tool_config=tool_config, + ) + + expected_dict = { + "type": "haystack_integrations.components.generators.amazon_bedrock.converse.converse_generator." 
+ "AmazonBedrockConverseGenerator", + "init_parameters": { + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, + "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "inference_config": { + "temperature": 0.1, + "maxTokens": 256, + "topP": 0.1, + "stopSequences": ["\\n"], + }, + "tool_config": tool_config.to_dict(), + "streaming_callback": None, + "system_prompt": None, + }, + } + + assert generator.to_dict() == expected_dict + + +def test_from_dict(mock_boto3_session): + """ + Test that the from_dict method returns the correct object + """ + tool_config = ToolConfig.from_functions([get_current_weather, get_current_time]) + + generator = AmazonBedrockConverseGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.amazon_bedrock.converse.converse_generator." + "AmazonBedrockConverseGenerator", + "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "inference_config": { + "temperature": 0.1, + "maxTokens": 256, + "topP": 0.1, + "stopSequences": ["\\n"], + }, + "tool_config": tool_config.to_dict(), + "streaming_callback": None, + "system_prompt": None, + }, + } + ) + + assert generator.inference_config["temperature"] == 0.1 + assert generator.inference_config["maxTokens"] == 256 + assert generator.inference_config["topP"] == 0.1 + assert generator.inference_config["stopSequences"] == ["\\n"] + assert generator.tool_config.to_dict() == tool_config.to_dict() + assert generator.model == "anthropic.claude-3-5-sonnet-20240620-v1:0" + + +def test_default_constructor(mock_boto3_session, set_env_variables): + """ + Test that the default constructor sets the correct values + """ + + layer = AmazonBedrockConverseGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", + ) + + assert layer.model == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert layer.inference_config is None + assert layer.tool_config is None + assert layer.streaming_callback is None + assert layer.system_prompt is None + + # assert mocked boto3 client called exactly once + mock_boto3_session.assert_called_once() + + # assert mocked boto3 client was called with the correct parameters + mock_boto3_session.assert_called_with( + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + profile_name="some_fake_profile", + region_name="fake_region", + ) + + +def test_constructor_with_empty_model(): + """ + Test that the constructor raises an error when the model is empty + """ + with pytest.raises(ValueError, match="cannot be None or empty string"): + 
AmazonBedrockConverseGenerator(model="") + + +def test_get_model_capabilities(mock_boto3_session): + generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") + assert generator.model_capabilities == MODEL_CAPABILITIES["anthropic.claude-3.*"] + + generator = AmazonBedrockConverseGenerator(model="ai21.j2-ultra-instruct-v1") + assert generator.model_capabilities == MODEL_CAPABILITIES["ai21.j2-.*"] + + with pytest.raises(ValueError, match="Unsupported model"): + AmazonBedrockConverseGenerator(model="unsupported.model-v1") + + +@patch("boto3.Session") +def test_run_with_different_message_types(mock_session): + mock_client = Mock() + mock_session.return_value.client.return_value = mock_client + mock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "Hello, how can I help you?"}]}}, + "usage": {"inputTokens": 10, "outputTokens": 20}, + "metrics": {"timeToFirstToken": 0.5}, + } + + generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") + + messages = [ + ConverseMessage.from_user(["What's the weather like?"]), + ConverseMessage.from_user([ImageBlock(format="png", source=ImageSource(bytes=b"fake_image_data"))]), + ] + + result = generator.run(messages) + + assert result["message"].role == ConverseRole.ASSISTANT + assert result["message"].content.content[0] == "Hello, how can I help you?" + assert result["usage"] == {"inputTokens": 10, "outputTokens": 20} + assert result["metrics"] == {"timeToFirstToken": 0.5} + + # Check the actual content sent to the API + mock_client.converse.assert_called_once() + call_args = mock_client.converse.call_args[1] + assert len(call_args["messages"]) == 2 + assert call_args["messages"][0]["content"] == [{"text": "What's the weather like?"}] + + +def test_streaming(mock_boto3_session): + generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") + + mocked_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "To"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " answer"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " your questions"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": ","}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " I'll"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " need to"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " use"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " two"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " different functions"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": ":"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " one"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " to check"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " the weather"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " in Paris and another"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " to get the current"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " time in New York"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": "."}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " Let"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " me"}, 
"contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " fetch"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " that"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " information for"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " you."}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, + { + "contentBlockStart": { + "start": {"toolUse": {"toolUseId": "tooluse_5Uu9EPSjQxiSsmc5Ex5MJg", "name": "get_current_weather"}}, + "contentBlockIndex": 1, + } + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"loc'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": 'ation":'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ' "Paris"'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ', "u'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": 'nit": "ce'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "lsius"}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '"}'}}, "contentBlockIndex": 1}}, + {"contentBlockStop": {"contentBlockIndex": 1}}, + { + "contentBlockStart": { + "start": {"toolUse": {"toolUseId": "tooluse_cbK-e15KTFqZHtwpBJ0kzg", "name": "get_current_time"}}, + "contentBlockIndex": 2, + } + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 2}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"timezon'}}, "contentBlockIndex": 2}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": 'e"'}}, "contentBlockIndex": 2}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ': "A'}}, "contentBlockIndex": 2}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "meric"}}, "contentBlockIndex": 2}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "a/New"}}, "contentBlockIndex": 2}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '_York"}'}}, "contentBlockIndex": 2}}, + {"contentBlockStop": {"contentBlockIndex": 2}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": {"inputTokens": 446, "outputTokens": 118, "totalTokens": 564}, + "metrics": {"latencyMs": 3930}, + } + }, + ] + + mock_stream = Mock() + mock_stream.__iter__ = Mock(return_value=iter(mocked_events)) + + generator.client.converse_stream = Mock(return_value={"stream": mock_stream}) + + chunks = [] + result = generator.run( + [ConverseMessage.from_user(["What's the weather like in Paris and what time is it in New York?"])], + streaming_callback=lambda chunk: chunks.append(chunk), + ) + + assert len(chunks) == len(mocked_events) + assert ( + result["message"].content.content[0] + == "To answer your questions, I'll need to use two different functions: one to check the weather " + "in Paris and another to get the current time in New York. Let me fetch that information for you." 
+ ) + assert len(result["message"].content.content) == 3 + assert result["stop_reason"] == "tool_use" + + assert result["message"].role == ConverseRole.ASSISTANT + + assert result["usage"] == {"inputTokens": 446, "outputTokens": 118, "totalTokens": 564} + assert result["metrics"] == {"latencyMs": 3930} + + assert result["message"].content.content[1].name == "get_current_weather" + assert result["message"].content.content[2].name == "get_current_time" + + assert json.dumps(result["message"].content.content[1].input) == """{"location": "Paris", "unit": "celsius"}""" + assert json.dumps(result["message"].content.content[2].input) == """{"timezone": "America/New_York"}""" + + +def test_client_error_handling(mock_boto3_session): + generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") + generator.client.converse = Mock( + side_effect=ClientError( + error_response={"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}}, + operation_name="Converse", + ) + ) + + with pytest.raises(AmazonBedrockInferenceError, match="Could not run inference on Amazon Bedrock model"): + generator.run([ConverseMessage.from_user(["Hi"])]) + + +def test_tool_usage(mock_boto3_session): + tool_config = ToolConfig.from_functions([get_current_weather, get_current_time]) + generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", tool_config=tool_config) + + mock_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "text": "I'll get the weather in Paris and the current time in New York for you. " + "To do this, I'll need to use two different tools. Let me fetch that data." + }, + { + "toolUse": { + "toolUseId": "tooluse_-Tp78_OeSq-1DSsP0B__TA", + "name": "get_current_weather", + "input": {"location": "Paris", "unit": "celsius"}, + } + }, + { + "toolUse": { + "toolUseId": "tooluse_gdYvqeiGTme7toWoV4sSKw", + "name": "get_current_time", + "input": {"timezone": "America/New_York"}, + }, + }, + ], + }, + }, + "stopReason": "tool_use", + } + generator.client.converse = Mock(return_value=mock_response) + + result = generator.run([ConverseMessage.from_user(["What's the weather in London?"])]) + + assert len(result["message"].content.content) == 3 + assert isinstance(result["message"].content.content[0], str) + assert isinstance(result["message"].content.content[1], ToolUseBlock) + assert isinstance(result["message"].content.content[2], ToolUseBlock) + assert result["stop_reason"] == "tool_use" + assert result["message"].role == ConverseRole.ASSISTANT + assert result["message"].content.content[0] == ( + "I'll get the weather in Paris and the current time in New York for you. " + "To do this, I'll need to use two different tools. Let me fetch that data." + ) + assert result["message"].content.content[1].name == "get_current_weather" + assert result["message"].content.content[2].name == "get_current_time" + assert json.dumps(result["message"].content.content[1].input) == """{"location": "Paris", "unit": "celsius"}""" + assert json.dumps(result["message"].content.content[2].input) == """{"timezone": "America/New_York"}""" + assert result["message"].content.content[1].input["location"] == "Paris" + assert result["message"].content.content[1].input["unit"] == "celsius" + assert result["message"].content.content[2].input["timezone"] == "America/New_York"
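Note: the example file and tests above stop after the first model turn. The sketch below shows one way the classes added in this diff can be combined into a complete tool-use round trip: run the generator, execute each returned `ToolUseBlock` locally, send the outputs back as `ToolResultBlock`s, and let the model produce a final answer. This is a minimal sketch, not code from this PR: the `AVAILABLE_TOOLS` mapping and the two tool functions are illustrative assumptions, and it presumes AWS credentials are configured in the environment.

```python
# Hypothetical follow-up to the converse_generator_example.py above; only the
# AVAILABLE_TOOLS mapping and the two tool functions are assumptions, the rest
# uses the API introduced in this PR.
from haystack_integrations.components.generators.amazon_bedrock import (
    AmazonBedrockConverseGenerator,
    ConverseMessage,
    ToolConfig,
)
from haystack_integrations.components.generators.amazon_bedrock.converse.utils import (
    ToolResultBlock,
    ToolResultContentBlock,
    ToolUseBlock,
)


def get_current_weather(location: str, unit: str = "celsius") -> str:
    """Get the current weather in a given location"""
    return f"The weather in {location} is 22 degrees {unit}."


def get_current_time(timezone: str) -> str:
    """Get the current time in a given timezone"""
    return f"The current time in {timezone} is 14:30."


AVAILABLE_TOOLS = {f.__name__: f for f in (get_current_weather, get_current_time)}

generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0")
tool_config = ToolConfig.from_functions(list(AVAILABLE_TOOLS.values()))

messages = [ConverseMessage.from_user(["What's the weather like in Paris?"])]
result = generator.run(messages=messages, tool_config=tool_config)

if result["stop_reason"] == "tool_use":
    # Echo the assistant turn (text plus tool calls) back into the conversation.
    messages.append(result["message"])

    # Execute each requested tool locally and wrap its output in a ToolResultBlock.
    tool_results = [
        ToolResultBlock(
            tool_use_id=block.tool_use_id,
            content=[ToolResultContentBlock(text=AVAILABLE_TOOLS[block.name](**block.input))],
        )
        for block in result["message"].content.content
        if isinstance(block, ToolUseBlock)
    ]
    messages.append(ConverseMessage.from_user(tool_results))

    # A second call lets the model turn the tool outputs into a final answer.
    result = generator.run(messages=messages, tool_config=tool_config)

print(result["message"].content.content)
```

The second `run` call reuses the same `tool_config`, so the model may keep requesting tools across turns; a production loop would repeat the dispatch step until `stop_reason` is no longer `"tool_use"`.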