Skip to content

Commit

Permalink
fix: fixes to Bedrock Chat Generator for compatibility with the new C…
Browse files Browse the repository at this point in the history
…hatMessage
  • Loading branch information
anakin87 committed Dec 17, 2024
1 parent 3a3419a commit 784fd88
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,10 @@ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[C
# Process each content block separately
for content_block in content_blocks:
if "text" in content_block:
replies.append(ChatMessage.from_assistant(content=content_block["text"], meta=base_meta.copy()))
replies.append(ChatMessage.from_assistant(content_block["text"], meta=base_meta.copy()))
elif "toolUse" in content_block:
replies.append(
ChatMessage.from_assistant(
content=json.dumps(content_block["toolUse"]), meta=base_meta.copy()
)
ChatMessage.from_assistant(json.dumps(content_block["toolUse"]), meta=base_meta.copy())
)
return replies

Expand Down Expand Up @@ -334,9 +332,9 @@ def process_streaming_response(
pass

tool_content = json.dumps(current_tool_use)
replies.append(ChatMessage.from_assistant(content=tool_content, meta=base_meta.copy()))
replies.append(ChatMessage.from_assistant(tool_content, meta=base_meta.copy()))
elif current_content:
replies.append(ChatMessage.from_assistant(content=current_content, meta=base_meta.copy()))
replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy()))

elif "messageStop" in event:
# not 100% correct for multiple messages but no way around it
Expand Down
36 changes: 18 additions & 18 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def test_default_inference_params(self, model_name, chat_messages):

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert first_reply.text, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
assert "paris" in first_reply.text.lower(), "First reply does not contain 'paris'"
assert first_reply.meta, "First reply has no metadata"

if first_reply.meta and "usage" in first_reply.meta:
Expand Down Expand Up @@ -197,9 +197,9 @@ def streaming_callback(chunk: StreamingChunk):

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert first_reply.text, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
assert "paris" in first_reply.text.lower(), "First reply does not contain 'paris'"
assert first_reply.meta, "First reply has no metadata"

@pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS)
Expand Down Expand Up @@ -246,17 +246,17 @@ def test_tools_use(self, model_name):

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert first_reply.text, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert first_reply.meta, "First reply has no metadata"

# Some models return thinking message as first and the second one as the tool call
if len(replies) > 1:
second_reply = replies[1]
assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance"
assert second_reply.content, "Second reply has no content"
assert second_reply.text, "Second reply has no content"
assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant"
tool_call = json.loads(second_reply.content)
tool_call = json.loads(second_reply.text)
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
Expand All @@ -266,7 +266,7 @@ def test_tools_use(self, model_name):
else:
# case where the model returns the tool call as the first message
# double check that the tool call is correct
tool_call = json.loads(first_reply.content)
tool_call = json.loads(first_reply.text)
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
Expand Down Expand Up @@ -318,17 +318,17 @@ def test_tools_use_with_streaming(self, model_name):

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert first_reply.text, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert first_reply.meta, "First reply has no metadata"

# Some models return thinking message as first and the second one as the tool call
if len(replies) > 1:
second_reply = replies[1]
assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance"
assert second_reply.content, "Second reply has no content"
assert second_reply.text, "Second reply has no content"
assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant"
tool_call = json.loads(second_reply.content)
tool_call = json.loads(second_reply.text)
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
Expand All @@ -338,7 +338,7 @@ def test_tools_use_with_streaming(self, model_name):
else:
# case where the model returns the tool call as the first message
# double check that the tool call is correct
tool_call = json.loads(first_reply.content)
tool_call = json.loads(first_reply.text)
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
Expand All @@ -361,7 +361,7 @@ def test_extract_replies_from_response(self, mock_boto3_session):

replies = generator.extract_replies_from_response(text_response)
assert len(replies) == 1
assert replies[0].content == "This is a test response"
assert replies[0].text == "This is a test response"
assert replies[0].role == ChatRole.ASSISTANT
assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0"
assert replies[0].meta["finish_reason"] == "complete"
Expand All @@ -381,7 +381,7 @@ def test_extract_replies_from_response(self, mock_boto3_session):

replies = generator.extract_replies_from_response(tool_response)
assert len(replies) == 1
tool_content = json.loads(replies[0].content)
tool_content = json.loads(replies[0].text)
assert tool_content["toolUseId"] == "123"
assert tool_content["name"] == "test_tool"
assert tool_content["input"] == {"key": "value"}
Expand All @@ -405,8 +405,8 @@ def test_extract_replies_from_response(self, mock_boto3_session):

replies = generator.extract_replies_from_response(mixed_response)
assert len(replies) == 2
assert replies[0].content == "Let me help you with that. I'll use the search tool to find the answer."
tool_content = json.loads(replies[1].content)
assert replies[0].text == "Let me help you with that. I'll use the search tool to find the answer."
tool_content = json.loads(replies[1].text)
assert tool_content["toolUseId"] == "456"
assert tool_content["name"] == "search_tool"
assert tool_content["input"] == {"query": "test"}
Expand Down Expand Up @@ -446,13 +446,13 @@ def test_callback(chunk: StreamingChunk):
# Verify final replies
assert len(replies) == 2
# Check text reply
assert replies[0].content == "Let me help you."
assert replies[0].text == "Let me help you."
assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0"
assert replies[0].meta["finish_reason"] == "complete"
assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}

# Check tool use reply
tool_content = json.loads(replies[1].content)
tool_content = json.loads(replies[1].text)
assert tool_content["toolUseId"] == "123"
assert tool_content["name"] == "search_tool"
assert tool_content["input"] == {"query": "test"}

0 comments on commit 784fd88

Please sign in to comment.