diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index d1bb71e99..4f1a5bef6 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -1,3 +1,4 @@ +import json import os import anthropic @@ -25,7 +26,7 @@ def test_init_default(self, monkeypatch): assert component.model == "claude-3-sonnet-20240229" assert component.streaming_callback is None assert not component.generation_kwargs - assert component.filter_thinking_for_tool_use + assert component.ignore_tools_thinking_messages def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) @@ -38,13 +39,13 @@ def test_init_with_parameters(self): model="claude-3-sonnet-20240229", streaming_callback=print_streaming_chunk, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - filter_thinking_for_tool_use=False, + ignore_tools_thinking_messages=False, ) assert component.client.api_key == "test-api-key" assert component.model == "claude-3-sonnet-20240229" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} - assert component.filter_thinking_for_tool_use is False + assert component.ignore_tools_thinking_messages is False def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") @@ -57,6 +58,7 @@ def test_to_dict_default(self, monkeypatch): "model": "claude-3-sonnet-20240229", "streaming_callback": None, "generation_kwargs": {}, + "ignore_tools_thinking_messages": True, }, } @@ -75,6 +77,7 @@ def test_to_dict_with_parameters(self, monkeypatch): "model": "claude-3-sonnet-20240229", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, }, } @@ -93,6 +96,7 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): "model": "claude-3-sonnet-20240229", "streaming_callback": "tests.test_chat_generator.<lambda>", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, }, } @@ -105,6 +109,7 @@ def test_from_dict(self, monkeypatch): "model": "claude-3-sonnet-20240229", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, }, } component = AnthropicChatGenerator.from_dict(data) @@ -122,6 +127,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "model": "claude-3-sonnet-20240229", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, }, } with pytest.raises(ValueError, match="None of the .* environment variables are set"): @@ -219,3 +225,40 @@ def streaming_callback(chunk: StreamingChunk): assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" + + @pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY", None), + reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.", + ) + @pytest.mark.integration + def test_tools_use(self): + # See https://docs.anthropic.com/en/docs/tool-use for more information + tools_schema = { + "name": "get_stock_price", + "description": "Retrieves the current stock price for a given ticker symbol.", + "input_schema": { + "type": "object", + "properties": { + "ticker": {"type": "string", "description": "The stock ticker symbol, e.g. AAPL for Apple Inc."} + }, + "required": ["ticker"], + }, + } + client = AnthropicChatGenerator() + response = client.run( + messages=[ChatMessage.from_user("What is the current price of AAPL?")], + generation_kwargs={"tools": [tools_schema]}, + ) + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "get_stock_price" in first_reply.content.lower(), "First reply does not contain get_stock_price" + assert first_reply.meta, "First reply has no metadata" + fc_response = json.loads(first_reply.content) + assert "name" in fc_response, "First reply does not contain name of the tool" + assert "input" in fc_response, "First reply does not contain input of the tool" diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index b0dcf8e8a..80a40bbfe 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -1,3 +1,4 @@ +import json import os from unittest.mock import Mock @@ -254,3 +255,40 @@ def __call__(self, chunk: StreamingChunk) -> None: assert message.meta["documents"] is not None assert message.meta["citations"] is not None + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_tools_use(self): + # See https://docs.anthropic.com/en/docs/tool-use for more information + tools_schema = { + "name": "get_stock_price", + "description": "Retrieves the current stock price for a given ticker symbol.", + "parameter_definitions": { + "ticker": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL for Apple Inc.", + "required": True, + } + }, + } + client = CohereChatGenerator(model="command-r") + response = client.run( + messages=[ChatMessage.from_user("What is the current price of AAPL?")], + generation_kwargs={"tools": [tools_schema]}, + ) + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "get_stock_price" in first_reply.content.lower(), "First reply does not contain get_stock_price" + assert first_reply.meta, "First reply has no metadata" + fc_response = json.loads(first_reply.content) + assert "name" in fc_response, "First reply does not contain name of the tool" + assert "parameters" in fc_response, "First reply does not contain parameters of the tool"