diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 575f52003..ad09fe2d3 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -211,10 +211,7 @@ def run( system_prompts = [{"text": messages[0].content}] messages = messages[1:] - messages_list = [ - {"role": msg.role.value, "content": [{"text": msg.content}]} - for msg in messages - ] + messages_list = [{"role": msg.role.value, "content": [{"text": msg.content}]} for msg in messages] try: # Build API parameters @@ -262,23 +259,17 @@ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[C "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0), "completion_tokens": response_body.get("usage", {}).get("outputTokens", 0), "total_tokens": response_body.get("usage", {}).get("totalTokens", 0), - } + }, } # Process each content block separately for content_block in content_blocks: if "text" in content_block: - replies.append( - ChatMessage.from_assistant( - content=content_block["text"], - meta=base_meta.copy() - ) - ) + replies.append(ChatMessage.from_assistant(content=content_block["text"], meta=base_meta.copy())) elif "toolUse" in content_block: replies.append( ChatMessage.from_assistant( - content=json.dumps(content_block["toolUse"]), - meta={**base_meta.copy()} + content=json.dumps(content_block["toolUse"]), meta=base_meta.copy() ) ) return replies @@ -305,7 +296,7 @@ def process_streaming_response( current_tool_use = { "toolUseId": tool_start["toolUseId"], "name": tool_start["name"], - "input": "" # Will accumulate deltas as string + "input": "", # Will accumulate deltas as string } elif "contentBlockDelta" in event: @@ -330,16 +321,9 @@ def process_streaming_response( pass tool_content = json.dumps(current_tool_use) - replies.append( - ChatMessage.from_assistant( - content=tool_content, - meta=base_meta.copy() - ) - ) + replies.append(ChatMessage.from_assistant(content=tool_content, meta=base_meta.copy())) elif current_content: - replies.append( - ChatMessage.from_assistant(content=current_content, meta=base_meta.copy()) - ) + replies.append(ChatMessage.from_assistant(content=current_content, meta=base_meta.copy())) elif "messageStop" in event: # not 100% correct for multiple messages but no way around it diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 2a0f7a535..5e5358218 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -7,8 +7,16 @@ from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" -MODELS_TO_TEST = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"] -MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"] +MODELS_TO_TEST = [ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "cohere.command-r-plus-v1:0", + "mistral.mistral-large-2402-v1:0", +] +MODELS_TO_TEST_WITH_TOOLS = [ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "cohere.command-r-plus-v1:0", + "mistral.mistral-large-2402-v1:0", +] # so far we've discovered these models support streaming and tool use STREAMING_TOOL_MODELS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0"] @@ -106,7 +114,9 @@ def test_constructor_with_generation_kwargs(mock_boto3_session): """ generation_kwargs = {"temperature": 0.7} - layer = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0", generation_kwargs=generation_kwargs) + layer = AmazonBedrockChatGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", generation_kwargs=generation_kwargs + ) assert layer.generation_kwargs == generation_kwargs @@ -199,19 +209,19 @@ def test_tools_use(self, model_name): "properties": { "sign": { "type": "string", - "description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ and WKRP." + "description": "The call sign for the radio station " + "for which you want the most popular song. " + "Example calls signs are WZPZ and WKRP.", } }, - "required": [ - "sign" - ] + "required": ["sign"], } - } + }, } } ], # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html - "toolChoice": {"auto": {}} + "toolChoice": {"auto": {}}, } messages = [] @@ -228,7 +238,6 @@ def test_tools_use(self, model_name): assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" assert first_reply.meta, "First reply has no metadata" - # Some models return thinking message as first and the second one as the tool call if len(replies) > 1: second_reply = replies[1] @@ -239,7 +248,9 @@ def test_tools_use(self, model_name): assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" else: # case where the model returns the tool call as the first message # double check that the tool call is correct @@ -247,7 +258,9 @@ def test_tools_use(self, model_name): assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) @pytest.mark.integration @@ -268,19 +281,19 @@ def test_tools_use_with_streaming(self, model_name): "properties": { "sign": { "type": "string", - "description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ and WKRP." + "description": "The call sign for the radio station " + "for which you want the most popular song. Example " + "calls signs are WZPZ and WKRP.", } }, - "required": [ - "sign" - ] + "required": ["sign"], } - } + }, } } ], # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html - "toolChoice": {"auto": {}} + "toolChoice": {"auto": {}}, } messages = [] @@ -307,7 +320,9 @@ def test_tools_use_with_streaming(self, model_name): assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" else: # case where the model returns the tool call as the first message # double check that the tool call is correct @@ -315,4 +330,6 @@ def test_tools_use_with_streaming(self, model_name): assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value"