Skip to content

Commit

Permalink
add serde test for toolinvoker
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Dec 20, 2024
1 parent d279ab2 commit 7130d02
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions test/components/tools/test_tool_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
from haystack.dataclasses.tool import Tool, ToolInvocationError
from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError
from haystack.components.generators.chat.openai import OpenAIChatGenerator


def weather_function(location):
Expand Down Expand Up @@ -218,3 +219,59 @@ def test_from_dict(self, weather_tool):
assert invoker._tools_with_names == {"weather_tool": weather_tool}
assert invoker.raise_on_failure
assert not invoker.convert_result_to_json_string

def test_serde_in_pipeline(self, invoker, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-key")

pipeline = Pipeline()
pipeline.add_component("invoker", invoker)
pipeline.add_component("chatgenerator", OpenAIChatGenerator())
pipeline.connect("invoker", "chatgenerator")

pipeline_dict = pipeline.to_dict()
assert pipeline_dict == {
"metadata": {},
"max_runs_per_component": 100,
"components": {
"invoker": {
"type": "haystack.components.tools.tool_invoker.ToolInvoker",
"init_parameters": {
"tools": [
{
"name": "weather_tool",
"description": "Provides weather information for a given location.",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
"required": ["location"],
},
"function": "tools.test_tool_invoker.weather_function",
}
],
"raise_on_failure": True,
"convert_result_to_json_string": False,
},
},
"chatgenerator": {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model": "gpt-4o-mini",
"streaming_callback": None,
"api_base_url": None,
"organization": None,
"generation_kwargs": {},
"max_retries": None,
"timeout": None,
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
"tools": None,
"tools_strict": False,
},
},
},
"connections": [{"sender": "invoker.tool_messages", "receiver": "chatgenerator.messages"}],
}

pipeline_yaml = pipeline.dumps()

new_pipeline = Pipeline.loads(pipeline_yaml)
assert new_pipeline == pipeline

0 comments on commit 7130d02

Please sign in to comment.