Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixes to Bedrock Chat Generator for compatibility with the new ChatMessage #1250

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,10 @@ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[C
# Process each content block separately
for content_block in content_blocks:
if "text" in content_block:
replies.append(ChatMessage.from_assistant(content=content_block["text"], meta=base_meta.copy()))
replies.append(ChatMessage.from_assistant(content_block["text"], meta=base_meta.copy()))
elif "toolUse" in content_block:
replies.append(
ChatMessage.from_assistant(
content=json.dumps(content_block["toolUse"]), meta=base_meta.copy()
)
ChatMessage.from_assistant(json.dumps(content_block["toolUse"]), meta=base_meta.copy())
)
return replies

Expand Down Expand Up @@ -334,9 +332,9 @@ def process_streaming_response(
pass

tool_content = json.dumps(current_tool_use)
replies.append(ChatMessage.from_assistant(content=tool_content, meta=base_meta.copy()))
replies.append(ChatMessage.from_assistant(tool_content, meta=base_meta.copy()))
elif current_content:
replies.append(ChatMessage.from_assistant(content=current_content, meta=base_meta.copy()))
replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy()))

elif "messageStop" in event:
# not 100% correct for multiple messages but no way around it
Expand Down
36 changes: 18 additions & 18 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def test_default_inference_params(self, model_name, chat_messages):

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert first_reply.text, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
assert "paris" in first_reply.text.lower(), "First reply does not contain 'paris'"
assert first_reply.meta, "First reply has no metadata"

if first_reply.meta and "usage" in first_reply.meta:
Expand Down Expand Up @@ -197,9 +197,9 @@ def streaming_callback(chunk: StreamingChunk):

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert first_reply.text, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
assert "paris" in first_reply.text.lower(), "First reply does not contain 'paris'"
assert first_reply.meta, "First reply has no metadata"

@pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS)
Expand Down Expand Up @@ -246,17 +246,17 @@ def test_tools_use(self, model_name):

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert first_reply.text, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert first_reply.meta, "First reply has no metadata"

# Some models return a thinking message first, followed by the tool call as the second message
if len(replies) > 1:
second_reply = replies[1]
assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance"
assert second_reply.content, "Second reply has no content"
assert second_reply.text, "Second reply has no content"
assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant"
tool_call = json.loads(second_reply.content)
tool_call = json.loads(second_reply.text)
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
Expand All @@ -266,7 +266,7 @@ def test_tools_use(self, model_name):
else:
# case where the model returns the tool call as the first message
# double check that the tool call is correct
tool_call = json.loads(first_reply.content)
tool_call = json.loads(first_reply.text)
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
Expand Down Expand Up @@ -318,17 +318,17 @@ def test_tools_use_with_streaming(self, model_name):

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert first_reply.text, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert first_reply.meta, "First reply has no metadata"

# Some models return a thinking message first, followed by the tool call as the second message
if len(replies) > 1:
second_reply = replies[1]
assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance"
assert second_reply.content, "Second reply has no content"
assert second_reply.text, "Second reply has no content"
assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant"
tool_call = json.loads(second_reply.content)
tool_call = json.loads(second_reply.text)
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
Expand All @@ -338,7 +338,7 @@ def test_tools_use_with_streaming(self, model_name):
else:
# case where the model returns the tool call as the first message
# double check that the tool call is correct
tool_call = json.loads(first_reply.content)
tool_call = json.loads(first_reply.text)
assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
Expand All @@ -361,7 +361,7 @@ def test_extract_replies_from_response(self, mock_boto3_session):

replies = generator.extract_replies_from_response(text_response)
assert len(replies) == 1
assert replies[0].content == "This is a test response"
assert replies[0].text == "This is a test response"
assert replies[0].role == ChatRole.ASSISTANT
assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0"
assert replies[0].meta["finish_reason"] == "complete"
Expand All @@ -381,7 +381,7 @@ def test_extract_replies_from_response(self, mock_boto3_session):

replies = generator.extract_replies_from_response(tool_response)
assert len(replies) == 1
tool_content = json.loads(replies[0].content)
tool_content = json.loads(replies[0].text)
assert tool_content["toolUseId"] == "123"
assert tool_content["name"] == "test_tool"
assert tool_content["input"] == {"key": "value"}
Expand All @@ -405,8 +405,8 @@ def test_extract_replies_from_response(self, mock_boto3_session):

replies = generator.extract_replies_from_response(mixed_response)
assert len(replies) == 2
assert replies[0].content == "Let me help you with that. I'll use the search tool to find the answer."
tool_content = json.loads(replies[1].content)
assert replies[0].text == "Let me help you with that. I'll use the search tool to find the answer."
tool_content = json.loads(replies[1].text)
assert tool_content["toolUseId"] == "456"
assert tool_content["name"] == "search_tool"
assert tool_content["input"] == {"query": "test"}
Expand Down Expand Up @@ -446,13 +446,13 @@ def test_callback(chunk: StreamingChunk):
# Verify final replies
assert len(replies) == 2
# Check text reply
assert replies[0].content == "Let me help you."
assert replies[0].text == "Let me help you."
assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0"
assert replies[0].meta["finish_reason"] == "complete"
assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}

# Check tool use reply
tool_content = json.loads(replies[1].content)
tool_content = json.loads(replies[1].text)
assert tool_content["toolUseId"] == "123"
assert tool_content["name"] == "search_tool"
assert tool_content["input"] == {"query": "test"}
Loading