From 784fd888d23380a77c671b54943c3a8fcbf25298 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 17 Dec 2024 18:23:15 +0100 Subject: [PATCH] fix: fixes to Bedrock Chat Generator for compatibility with the new ChatMessage --- .../amazon_bedrock/chat/chat_generator.py | 10 +++--- .../tests/test_chat_generator.py | 36 +++++++++---------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 499fe1c24..bcf11414c 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -278,12 +278,10 @@ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[C # Process each content block separately for content_block in content_blocks: if "text" in content_block: - replies.append(ChatMessage.from_assistant(content=content_block["text"], meta=base_meta.copy())) + replies.append(ChatMessage.from_assistant(content_block["text"], meta=base_meta.copy())) elif "toolUse" in content_block: replies.append( - ChatMessage.from_assistant( - content=json.dumps(content_block["toolUse"]), meta=base_meta.copy() - ) + ChatMessage.from_assistant(json.dumps(content_block["toolUse"]), meta=base_meta.copy()) ) return replies @@ -334,9 +332,9 @@ def process_streaming_response( pass tool_content = json.dumps(current_tool_use) - replies.append(ChatMessage.from_assistant(content=tool_content, meta=base_meta.copy())) + replies.append(ChatMessage.from_assistant(tool_content, meta=base_meta.copy())) elif current_content: - replies.append(ChatMessage.from_assistant(content=current_content, meta=base_meta.copy())) + replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy())) elif "messageStop" in event: # not 100% correct for multiple messages but no way around it diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 8eb29729c..c2122163c 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -163,9 +163,9 @@ def test_default_inference_params(self, model_name, chat_messages): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert "paris" in first_reply.text.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" if first_reply.meta and "usage" in first_reply.meta: @@ -197,9 +197,9 @@ def streaming_callback(chunk: StreamingChunk): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert "paris" in first_reply.text.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) @@ -246,7 +246,7 @@ def test_tools_use(self, model_name): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" assert first_reply.meta, "First reply has no metadata" @@ -254,9 +254,9 @@ def test_tools_use(self, model_name): if len(replies) > 1: second_reply = replies[1] assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" - assert second_reply.content, "Second reply has no content" + assert second_reply.text, "Second reply has no content" assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" - tool_call = json.loads(second_reply.content) + tool_call = json.loads(second_reply.text) assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" @@ -266,7 +266,7 @@ def test_tools_use(self, model_name): else: # case where the model returns the tool call as the first message # double check that the tool call is correct - tool_call = json.loads(first_reply.content) + tool_call = json.loads(first_reply.text) assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" @@ -318,7 +318,7 @@ def test_tools_use_with_streaming(self, model_name): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" assert first_reply.meta, "First reply has no metadata" @@ -326,9 +326,9 @@ def test_tools_use_with_streaming(self, model_name): if len(replies) > 1: second_reply = replies[1] assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" - assert second_reply.content, "Second reply has no content" + assert second_reply.text, "Second reply has no content" assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" - tool_call = json.loads(second_reply.content) + tool_call = json.loads(second_reply.text) assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" @@ -338,7 +338,7 @@ def test_tools_use_with_streaming(self, model_name): else: # case where the model returns the tool call as the first message # double check that the tool call is correct - tool_call = json.loads(first_reply.content) + tool_call = json.loads(first_reply.text) assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" @@ -361,7 +361,7 @@ def test_extract_replies_from_response(self, mock_boto3_session): replies = generator.extract_replies_from_response(text_response) assert len(replies) == 1 - assert replies[0].content == "This is a test response" + assert replies[0].text == "This is a test response" assert replies[0].role == ChatRole.ASSISTANT assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" assert replies[0].meta["finish_reason"] == "complete" @@ -381,7 +381,7 @@ def test_extract_replies_from_response(self, mock_boto3_session): replies = generator.extract_replies_from_response(tool_response) assert len(replies) == 1 - tool_content = json.loads(replies[0].content) + tool_content = json.loads(replies[0].text) assert tool_content["toolUseId"] == "123" assert tool_content["name"] == "test_tool" assert tool_content["input"] == {"key": "value"} @@ -405,8 +405,8 @@ def test_extract_replies_from_response(self, mock_boto3_session): replies = generator.extract_replies_from_response(mixed_response) assert len(replies) == 2 - assert replies[0].content == "Let me help you with that. I'll use the search tool to find the answer." - tool_content = json.loads(replies[1].content) + assert replies[0].text == "Let me help you with that. I'll use the search tool to find the answer." + tool_content = json.loads(replies[1].text) assert tool_content["toolUseId"] == "456" assert tool_content["name"] == "search_tool" assert tool_content["input"] == {"query": "test"} @@ -446,13 +446,13 @@ def test_callback(chunk: StreamingChunk): # Verify final replies assert len(replies) == 2 # Check text reply - assert replies[0].content == "Let me help you." + assert replies[0].text == "Let me help you." assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" assert replies[0].meta["finish_reason"] == "complete" assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} # Check tool use reply - tool_content = json.loads(replies[1].content) + tool_content = json.loads(replies[1].text) assert tool_content["toolUseId"] == "123" assert tool_content["name"] == "search_tool" assert tool_content["input"] == {"query": "test"}