feat: Update Anthropic/Cohere for tools use #790

Merged · 6 commits · Jun 18, 2024
Anthropic chat generator
@@ -1,4 +1,5 @@
import dataclasses
import json
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
@@ -7,13 +8,14 @@

from anthropic import Anthropic, Stream
from anthropic.types import (
ContentBlock,
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
MessageStartEvent,
MessageStreamEvent,
TextBlock,
TextDelta,
ToolUseBlock,
)

logger = logging.getLogger(__name__)
@@ -61,6 +63,8 @@ class AnthropicChatGenerator:
# The parameters that can be passed to the Anthropic API https://docs.anthropic.com/claude/reference/messages_post
ALLOWED_PARAMS: ClassVar[List[str]] = [
"system",
"tools",
"tool_choice",
"max_tokens",
"metadata",
"stop_sequences",
@@ -75,6 +79,7 @@ def __init__(
model: str = "claude-3-sonnet-20240229",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
ignore_tools_thinking_messages: bool = True,
):
"""
Creates an instance of AnthropicChatGenerator.
@@ -95,13 +100,18 @@ def __init__(
- `temperature`: The temperature to use for sampling.
- `top_p`: The top_p value to use for nucleus sampling.
- `top_k`: The top_k value to use for top-k sampling.

:param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves
    "chain of thought" messages before returning the actual function names and parameters in a message. If
    `ignore_tools_thinking_messages` is `True`, the generator will drop these so-called thinking messages when
    tool use is detected. See the Anthropic [tools documentation](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use)
    for more details.
Member:
Probably this should be indented.
Let's also put `ignore_tools_thinking_messages` and `True` inside backticks.

"""
self.api_key = api_key
self.model = model
self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback
self.client = Anthropic(api_key=self.api_key.resolve_value())
self.ignore_tools_thinking_messages = ignore_tools_thinking_messages
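For illustration only (not part of the diff), a minimal usage sketch of the new flag; the import paths are assumed from the integration's package layout, the tool schema mirrors the test added below, and `ANTHROPIC_API_KEY` is expected to be set in the environment:

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator

# Tool description in Anthropic's input_schema format (same schema as in the new test).
stock_tool = {
    "name": "get_stock_price",
    "description": "Retrieves the current stock price for a given ticker symbol.",
    "input_schema": {
        "type": "object",
        "properties": {"ticker": {"type": "string", "description": "The stock ticker symbol, e.g. AAPL."}},
        "required": ["ticker"],
    },
}

# Setting ignore_tools_thinking_messages=False keeps the model's "thinking" text
# alongside the tool-use block instead of dropping it.
generator = AnthropicChatGenerator(ignore_tools_thinking_messages=False)
result = generator.run(
    messages=[ChatMessage.from_user("What is the current price of AAPL?")],
    generation_kwargs={"tools": [stock_tool]},
)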

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@@ -123,6 +133,7 @@ def to_dict(self) -> Dict[str, Any]:
streaming_callback=callback_name,
generation_kwargs=self.generation_kwargs,
api_key=self.api_key.to_dict(),
ignore_tools_thinking_messages=self.ignore_tools_thinking_messages,
)

@classmethod
@@ -203,18 +214,24 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
completions = [self._connect_chunks(chunks, start_event, delta)]
# if streaming is disabled, the response is an Anthropic Message
elif isinstance(response, Message):
has_tools_msgs = any(isinstance(content_block, ToolUseBlock) for content_block in response.content)
if has_tools_msgs and self.ignore_tools_thinking_messages:
response.content = [block for block in response.content if isinstance(block, ToolUseBlock)]
Comment on lines +217 to +219

Member:
Just to better understand...
what happens to tool messages if streaming is enabled?

Member Author:
Yes, good question. There is no support yet for function calling (FC) with streaming generators. I think we should mention that somewhere, and do so across all generators.

completions = [self._build_message(content_block, response) for content_block in response.content]

return {"replies": completions}

def _build_message(self, content_block: ContentBlock, message: Message) -> ChatMessage:
def _build_message(self, content_block: Union[TextBlock, ToolUseBlock], message: Message) -> ChatMessage:
"""
Converts the non-streaming Anthropic Message to a ChatMessage.
:param content_block: The content block of the message.
:param message: The non-streaming Anthropic Message.
:returns: The ChatMessage.
"""
chat_message = ChatMessage.from_assistant(content_block.text)
if isinstance(content_block, TextBlock):
chat_message = ChatMessage.from_assistant(content_block.text)
else:
chat_message = ChatMessage.from_assistant(json.dumps(content_block.model_dump(mode="json")))
chat_message.meta.update(
{
"model": message.model,
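With the `_build_message` change above, a tool-use reply reaches the caller as JSON text in `ChatMessage.content`, while plain answers stay as text. A rough consuming sketch (assumes `result` comes from a `run()` call with tools, as in the earlier sketch; the `name` and `input` keys are the ones asserted in the test below):

import json

for reply in result["replies"]:
    try:
        tool_call = json.loads(reply.content)  # serialized ToolUseBlock
    except json.JSONDecodeError:
        print(reply.content)  # ordinary text answer
        continue
    print(tool_call["name"], tool_call["input"])  # e.g. get_stock_price {'ticker': 'AAPL'}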
integrations/anthropic/tests/test_chat_generator.py (46 additions, 0 deletions)
@@ -1,3 +1,4 @@
import json
import os

import anthropic
@@ -25,6 +26,7 @@ def test_init_default(self, monkeypatch):
assert component.model == "claude-3-sonnet-20240229"
assert component.streaming_callback is None
assert not component.generation_kwargs
assert component.ignore_tools_thinking_messages

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
@@ -37,11 +39,13 @@ def test_init_with_parameters(self):
model="claude-3-sonnet-20240229",
streaming_callback=print_streaming_chunk,
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
ignore_tools_thinking_messages=False,
)
assert component.client.api_key == "test-api-key"
assert component.model == "claude-3-sonnet-20240229"
assert component.streaming_callback is print_streaming_chunk
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
assert component.ignore_tools_thinking_messages is False

def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key")
@@ -54,6 +58,7 @@ def test_to_dict_default(self, monkeypatch):
"model": "claude-3-sonnet-20240229",
"streaming_callback": None,
"generation_kwargs": {},
"ignore_tools_thinking_messages": True,
},
}

@@ -72,6 +77,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
"model": "claude-3-sonnet-20240229",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"ignore_tools_thinking_messages": True,
},
}

@@ -90,6 +96,7 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
"model": "claude-3-sonnet-20240229",
"streaming_callback": "tests.test_chat_generator.<lambda>",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"ignore_tools_thinking_messages": True,
},
}

@@ -102,6 +109,7 @@ def test_from_dict(self, monkeypatch):
"model": "claude-3-sonnet-20240229",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"ignore_tools_thinking_messages": True,
},
}
component = AnthropicChatGenerator.from_dict(data)
@@ -119,6 +127,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
"model": "claude-3-sonnet-20240229",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"ignore_tools_thinking_messages": True,
},
}
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
@@ -216,3 +225,40 @@ def streaming_callback(chunk: StreamingChunk):
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
assert first_reply.meta, "First reply has no metadata"

@pytest.mark.skipif(
not os.environ.get("ANTHROPIC_API_KEY", None),
reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.",
)
@pytest.mark.integration
def test_tools_use(self):
# See https://docs.anthropic.com/en/docs/tool-use for more information
tools_schema = {
"name": "get_stock_price",
"description": "Retrieves the current stock price for a given ticker symbol.",
"input_schema": {
"type": "object",
"properties": {
"ticker": {"type": "string", "description": "The stock ticker symbol, e.g. AAPL for Apple Inc."}
},
"required": ["ticker"],
},
}
client = AnthropicChatGenerator()
response = client.run(
messages=[ChatMessage.from_user("What is the current price of AAPL?")],
generation_kwargs={"tools": [tools_schema]},
)
replies = response["replies"]
assert isinstance(replies, list), "Replies is not a list"
assert len(replies) > 0, "No replies received"

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert "get_stock_price" in first_reply.content.lower(), "First reply does not contain get_stock_price"
assert first_reply.meta, "First reply has no metadata"
fc_response = json.loads(first_reply.content)
assert "name" in fc_response, "First reply does not contain name of the tool"
assert "input" in fc_response, "First reply does not contain input of the tool"
Cohere chat generator
@@ -215,9 +215,12 @@ def _build_message(self, cohere_response):
:param cohere_response: The completion returned by the Cohere API.
:returns: The ChatMessage.
"""
content = cohere_response.text
message = ChatMessage.from_assistant(content=content)

message = None
if cohere_response.tool_calls:
# TODO revisit to see if we need to handle multiple tool calls
message = ChatMessage.from_assistant(cohere_response.tool_calls[0].json())
elif cohere_response.text:
message = ChatMessage.from_assistant(content=cohere_response.text)
total_tokens = cohere_response.meta.billed_units.input_tokens + cohere_response.meta.billed_units.output_tokens
message.meta.update(
{
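With the Cohere change above, a tool call likewise arrives as JSON text in the reply content. A minimal consuming sketch (assumes `response` is the output of a `CohereChatGenerator.run()` call with tools, as in the test added further down; the `name` and `parameters` keys are the ones that test asserts):

import json

reply = response["replies"][0]
tool_call = json.loads(reply.content)       # serialized Cohere tool call
assert tool_call["name"] == "get_stock_price"
ticker = tool_call["parameters"]["ticker"]  # e.g. "AAPL"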
integrations/cohere/tests/test_cohere_chat_generator.py (38 additions, 0 deletions)
@@ -1,3 +1,4 @@
import json
import os
from unittest.mock import Mock

@@ -254,3 +255,40 @@ def __call__(self, chunk: StreamingChunk) -> None:

assert message.meta["documents"] is not None
assert message.meta["citations"] is not None

@pytest.mark.skipif(
not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),
reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.",
)
@pytest.mark.integration
def test_tools_use(self):
# See Cohere's tool use documentation for more information
tools_schema = {
"name": "get_stock_price",
"description": "Retrieves the current stock price for a given ticker symbol.",
"parameter_definitions": {
"ticker": {
"type": "string",
"description": "The stock ticker symbol, e.g. AAPL for Apple Inc.",
"required": True,
}
},
}
client = CohereChatGenerator(model="command-r")
response = client.run(
messages=[ChatMessage.from_user("What is the current price of AAPL?")],
generation_kwargs={"tools": [tools_schema]},
)
replies = response["replies"]
assert isinstance(replies, list), "Replies is not a list"
assert len(replies) > 0, "No replies received"

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert "get_stock_price" in first_reply.content.lower(), "First reply does not contain get_stock_price"
assert first_reply.meta, "First reply has no metadata"
fc_response = json.loads(first_reply.content)
assert "name" in fc_response, "First reply does not contain name of the tool"
assert "parameters" in fc_response, "First reply does not contain parameters of the tool"
Cohere generator tests
@@ -91,7 +91,7 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
"type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator",
"init_parameters": {
"model": "command",
"streaming_callback": "tests.test_cohere_generators.<lambda>",
"streaming_callback": "tests.test_cohere_generator.<lambda>",
"api_base_url": "test-base-url",
"api_key": {"type": "env_var", "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True},
"generation_kwargs": {},