Skip to content

Commit

Permalink
fix(Bedrock): allow tools kwargs for AWS Bedrock Claude model (#976)
Browse files Browse the repository at this point in the history
  • Loading branch information
lambda-science authored Aug 20, 2024
1 parent 39decab commit 3085cf9
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ class AnthropicClaudeChatAdapter(BedrockModelChatAdapter):
"top_p",
"top_k",
"system",
"tools",
"tool_choice",
]

def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]):
Expand Down Expand Up @@ -253,10 +255,18 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List
"""
messages: List[ChatMessage] = []
if response_body.get("type") == "message":
for content in response_body["content"]:
if content.get("type") == "text":
meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]}
messages.append(ChatMessage.from_assistant(content["text"], meta=meta))
if response_body.get("stop_reason") == "tool_use": # If `tool_use` we only keep the tool_use content
for content in response_body["content"]:
if content.get("type") == "tool_use":
meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]}
json_answer = json.dumps(content)
messages.append(ChatMessage.from_assistant(json_answer, meta=meta))
else: # For other stop_reason, return all text content
for content in response_body["content"]:
if content.get("type") == "text":
meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]}
messages.append(ChatMessage.from_assistant(content["text"], meta=meta))

return messages

def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
Expand Down
44 changes: 44 additions & 0 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
from typing import Optional, Type
Expand All @@ -17,6 +18,7 @@

KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"
MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"]
MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-haiku-20240307-v1:0"]
MISTRAL_MODELS = [
"mistral.mistral-7b-instruct-v0:2",
"mistral.mixtral-8x7b-instruct-v0:1",
Expand Down Expand Up @@ -303,6 +305,48 @@ def test_prepare_body_with_custom_inference_params(self) -> None:

assert body == expected_body

@pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS)
@pytest.mark.integration
def test_tools_use(self, model_name):
"""
Test function calling with AWS Bedrock Anthropic adapter
"""
# See https://docs.anthropic.com/en/docs/tool-use for more information
tools = [
{
"name": "top_song",
"description": "Get the most popular song played on a radio station.",
"input_schema": {
"type": "object",
"properties": {
"sign": {
"type": "string",
"description": "The call sign for the radio station for which you want the most popular"
" song. Example calls signs are WZPZ and WKRP.",
}
},
"required": ["sign"],
},
}
]
messages = []
messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?"))
client = AmazonBedrockChatGenerator(model=model_name)
response = client.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": {"type": "any"}})
replies = response["replies"]
assert isinstance(replies, list), "Replies is not a list"
assert len(replies) > 0, "No replies received"

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert "top_song" in first_reply.content.lower(), "First reply does not contain top_song"
assert first_reply.meta, "First reply has no metadata"
fc_response = json.loads(first_reply.content)
assert "name" in fc_response, "First reply does not contain name of the tool"
assert "input" in fc_response, "First reply does not contain input of the tool"


class TestMistralAdapter:
def test_prepare_body_with_default_params(self) -> None:
Expand Down

0 comments on commit 3085cf9

Please sign in to comment.