Skip to content

Commit

Permalink
Add tools usage integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Jun 17, 2024
1 parent a0f5acf commit dffa818
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 3 deletions.
49 changes: 46 additions & 3 deletions integrations/anthropic/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os

import anthropic
Expand Down Expand Up @@ -25,7 +26,7 @@ def test_init_default(self, monkeypatch):
assert component.model == "claude-3-sonnet-20240229"
assert component.streaming_callback is None
assert not component.generation_kwargs
assert component.filter_thinking_for_tool_use
assert component.ignore_tools_thinking_messages

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
Expand All @@ -38,13 +39,13 @@ def test_init_with_parameters(self):
model="claude-3-sonnet-20240229",
streaming_callback=print_streaming_chunk,
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
filter_thinking_for_tool_use=False,
ignore_tools_thinking_messages=False,
)
assert component.client.api_key == "test-api-key"
assert component.model == "claude-3-sonnet-20240229"
assert component.streaming_callback is print_streaming_chunk
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
assert component.filter_thinking_for_tool_use is False
assert component.ignore_tools_thinking_messages is False

def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key")
Expand All @@ -57,6 +58,7 @@ def test_to_dict_default(self, monkeypatch):
"model": "claude-3-sonnet-20240229",
"streaming_callback": None,
"generation_kwargs": {},
"ignore_tools_thinking_messages": True,
},
}

Expand All @@ -75,6 +77,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
"model": "claude-3-sonnet-20240229",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"ignore_tools_thinking_messages": True,
},
}

Expand All @@ -93,6 +96,7 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
"model": "claude-3-sonnet-20240229",
"streaming_callback": "tests.test_chat_generator.<lambda>",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"ignore_tools_thinking_messages": True,
},
}

Expand All @@ -105,6 +109,7 @@ def test_from_dict(self, monkeypatch):
"model": "claude-3-sonnet-20240229",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"ignore_tools_thinking_messages": True,
},
}
component = AnthropicChatGenerator.from_dict(data)
Expand All @@ -122,6 +127,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
"model": "claude-3-sonnet-20240229",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"ignore_tools_thinking_messages": True,
},
}
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
Expand Down Expand Up @@ -219,3 +225,40 @@ def streaming_callback(chunk: StreamingChunk):
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
assert first_reply.meta, "First reply has no metadata"

@pytest.mark.skipif(
not os.environ.get("ANTHROPIC_API_KEY", None),
reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.",
)
@pytest.mark.integration
def test_tools_use(self):
# See https://docs.anthropic.com/en/docs/tool-use for more information
tools_schema = {
"name": "get_stock_price",
"description": "Retrieves the current stock price for a given ticker symbol.",
"input_schema": {
"type": "object",
"properties": {
"ticker": {"type": "string", "description": "The stock ticker symbol, e.g. AAPL for Apple Inc."}
},
"required": ["ticker"],
},
}
client = AnthropicChatGenerator()
response = client.run(
messages=[ChatMessage.from_user("What is the current price of AAPL?")],
generation_kwargs={"tools": [tools_schema]},
)
replies = response["replies"]
assert isinstance(replies, list), "Replies is not a list"
assert len(replies) > 0, "No replies received"

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert "get_stock_price" in first_reply.content.lower(), "First reply does not contain get_stock_price"
assert first_reply.meta, "First reply has no metadata"
fc_response = json.loads(first_reply.content)
assert "name" in fc_response, "First reply does not contain name of the tool"
assert "input" in fc_response, "First reply does not contain input of the tool"
38 changes: 38 additions & 0 deletions integrations/cohere/tests/test_cohere_chat_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
from unittest.mock import Mock

Expand Down Expand Up @@ -254,3 +255,40 @@ def __call__(self, chunk: StreamingChunk) -> None:

assert message.meta["documents"] is not None
assert message.meta["citations"] is not None

@pytest.mark.skipif(
not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),
reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.",
)
@pytest.mark.integration
def test_tools_use(self):
# See https://docs.anthropic.com/en/docs/tool-use for more information
tools_schema = {
"name": "get_stock_price",
"description": "Retrieves the current stock price for a given ticker symbol.",
"parameter_definitions": {
"ticker": {
"type": "string",
"description": "The stock ticker symbol, e.g. AAPL for Apple Inc.",
"required": True,
}
},
}
client = CohereChatGenerator(model="command-r")
response = client.run(
messages=[ChatMessage.from_user("What is the current price of AAPL?")],
generation_kwargs={"tools": [tools_schema]},
)
replies = response["replies"]
assert isinstance(replies, list), "Replies is not a list"
assert len(replies) > 0, "No replies received"

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert "get_stock_price" in first_reply.content.lower(), "First reply does not contain get_stock_price"
assert first_reply.meta, "First reply has no metadata"
fc_response = json.loads(first_reply.content)
assert "name" in fc_response, "First reply does not contain name of the tool"
assert "parameters" in fc_response, "First reply does not contain parameters of the tool"

0 comments on commit dffa818

Please sign in to comment.