From 96b4a1d2fd82f3e072060d8c11ae3e1fc230d681 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 18 Dec 2024 12:36:44 +0100 Subject: [PATCH] feat: `Tool` dataclass - unified abstraction to represent tools (#8652) * draft * del HF token in tests * adaptations * progress * fix type * import sorting * more control on deserialization * release note * improvements * support name field * fix chatpromptbuilder test * port Tool from experimental * release note * docs upd * Update tool.py --------- Co-authored-by: Daria Fokina --- docs/pydoc/config/data_classess_api.yml | 2 +- haystack/dataclasses/__init__.py | 2 + haystack/dataclasses/tool.py | 243 ++++++++++++++ pyproject.toml | 3 +- .../tool-dataclass-12756077bbfea3a1.yaml | 8 + test/dataclasses/test_tool.py | 305 ++++++++++++++++++ 6 files changed, 561 insertions(+), 2 deletions(-) create mode 100644 haystack/dataclasses/tool.py create mode 100644 releasenotes/notes/tool-dataclass-12756077bbfea3a1.yaml create mode 100644 test/dataclasses/test_tool.py diff --git a/docs/pydoc/config/data_classess_api.yml b/docs/pydoc/config/data_classess_api.yml index a67f28db9d..71ea77513a 100644 --- a/docs/pydoc/config/data_classess_api.yml +++ b/docs/pydoc/config/data_classess_api.yml @@ -2,7 +2,7 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../haystack/dataclasses] modules: - ["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding"] + ["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding", "tool"] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack/dataclasses/__init__.py b/haystack/dataclasses/__init__.py index 91e8f0408f..97f253e805 100644 --- a/haystack/dataclasses/__init__.py +++ b/haystack/dataclasses/__init__.py @@ -8,6 +8,7 @@ from haystack.dataclasses.document import Document from haystack.dataclasses.sparse_embedding import SparseEmbedding from haystack.dataclasses.streaming_chunk import StreamingChunk +from haystack.dataclasses.tool import Tool __all__ = [ "Document", @@ -22,4 +23,5 @@ "TextContent", "StreamingChunk", "SparseEmbedding", + "Tool", ] diff --git a/haystack/dataclasses/tool.py b/haystack/dataclasses/tool.py new file mode 100644 index 0000000000..3df3fd18f2 --- /dev/null +++ b/haystack/dataclasses/tool.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import inspect +from dataclasses import asdict, dataclass +from typing import Any, Callable, Dict, Optional + +from pydantic import create_model + +from haystack.lazy_imports import LazyImport +from haystack.utils import deserialize_callable, serialize_callable + +with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: + from jsonschema import Draft202012Validator + from jsonschema.exceptions import SchemaError + + +class ToolInvocationError(Exception): + """ + Exception raised when a Tool invocation fails. + """ + + pass + + +class SchemaGenerationError(Exception): + """ + Exception raised when automatic schema generation fails. + """ + + pass + + +@dataclass +class Tool: + """ + Data class representing a Tool that Language Models can prepare a call for. + + Accurate definitions of the textual attributes such as `name` and `description` + are important for the Language Model to correctly prepare the call. + + :param name: + Name of the Tool. + :param description: + Description of the Tool. + :param parameters: + A JSON schema defining the parameters expected by the Tool. + :param function: + The function that will be invoked when the Tool is called. + """ + + name: str + description: str + parameters: Dict[str, Any] + function: Callable + + def __post_init__(self): + jsonschema_import.check() + # Check that the parameters define a valid JSON schema + try: + Draft202012Validator.check_schema(self.parameters) + except SchemaError as e: + raise ValueError("The provided parameters do not define a valid JSON schema") from e + + @property + def tool_spec(self) -> Dict[str, Any]: + """ + Return the Tool specification to be used by the Language Model. + """ + return {"name": self.name, "description": self.description, "parameters": self.parameters} + + def invoke(self, **kwargs) -> Any: + """ + Invoke the Tool with the provided keyword arguments. + """ + + try: + result = self.function(**kwargs) + except Exception as e: + raise ToolInvocationError(f"Failed to invoke Tool `{self.name}` with parameters {kwargs}") from e + return result + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the Tool to a dictionary. + + :returns: + Dictionary with serialized data. + """ + + serialized = asdict(self) + serialized["function"] = serialize_callable(self.function) + return serialized + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Tool": + """ + Deserializes the Tool from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized Tool. + """ + data["function"] = deserialize_callable(data["function"]) + return cls(**data) + + @classmethod + def from_function(cls, function: Callable, name: Optional[str] = None, description: Optional[str] = None) -> "Tool": + """ + Create a Tool instance from a function. + + ### Usage example + + ```python + from typing import Annotated, Literal + from haystack.dataclasses import Tool + + def get_weather( + city: Annotated[str, "the city for which to get the weather"] = "Munich", + unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius"): + '''A simple function to get the current weather for a location.''' + return f"Weather report for {city}: 20 {unit}, sunny" + + tool = Tool.from_function(get_weather) + + print(tool) + >>> Tool(name='get_weather', description='A simple function to get the current weather for a location.', + >>> parameters={ + >>> 'type': 'object', + >>> 'properties': { + >>> 'city': {'type': 'string', 'description': 'the city for which to get the weather', 'default': 'Munich'}, + >>> 'unit': { + >>> 'type': 'string', + >>> 'enum': ['Celsius', 'Fahrenheit'], + >>> 'description': 'the unit for the temperature', + >>> 'default': 'Celsius', + >>> }, + >>> } + >>> }, + >>> function=) + ``` + + :param function: + The function to be converted into a Tool. + The function must include type hints for all parameters. + If a parameter is annotated using `typing.Annotated`, its metadata will be used as parameter description. + :param name: + The name of the Tool. If not provided, the name of the function will be used. + :param description: + The description of the Tool. If not provided, the docstring of the function will be used. + To intentionally leave the description empty, pass an empty string. + + :returns: + The Tool created from the function. + + :raises ValueError: + If any parameter of the function lacks a type hint. + :raises SchemaGenerationError: + If there is an error generating the JSON schema for the Tool. + """ + + tool_description = description if description is not None else (function.__doc__ or "") + + signature = inspect.signature(function) + + # collect fields (types and defaults) and descriptions from function parameters + fields: Dict[str, Any] = {} + descriptions = {} + + for param_name, param in signature.parameters.items(): + if param.annotation is param.empty: + raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.") + + # if the parameter has not a default value, Pydantic requires an Ellipsis (...) + # to explicitly indicate that the parameter is required + default = param.default if param.default is not param.empty else ... + fields[param_name] = (param.annotation, default) + + if hasattr(param.annotation, "__metadata__"): + descriptions[param_name] = param.annotation.__metadata__[0] + + # create Pydantic model and generate JSON schema + try: + model = create_model(function.__name__, **fields) + schema = model.model_json_schema() + except Exception as e: + raise SchemaGenerationError(f"Failed to create JSON schema for function '{function.__name__}'") from e + + # we don't want to include title keywords in the schema, as they contain redundant information + # there is no programmatic way to prevent Pydantic from adding them, so we remove them later + # see https://github.com/pydantic/pydantic/discussions/8504 + _remove_title_from_schema(schema) + + # add parameters descriptions to the schema + for param_name, param_description in descriptions.items(): + if param_name in schema["properties"]: + schema["properties"][param_name]["description"] = param_description + + return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function) + + +def _remove_title_from_schema(schema: Dict[str, Any]): + """ + Remove the 'title' keyword from JSON schema and contained property schemas. + + :param schema: + The JSON schema to remove the 'title' keyword from. + """ + schema.pop("title", None) + + for property_schema in schema["properties"].values(): + for key in list(property_schema.keys()): + if key == "title": + del property_schema[key] + + +def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"): + """ + Deserialize Tools in a dictionary inplace. + + :param data: + The dictionary with the serialized data. + :param key: + The key in the dictionary where the Tools are stored. + """ + if key in data: + serialized_tools = data[key] + + if serialized_tools is None: + return + + if not isinstance(serialized_tools, list): + raise TypeError(f"The value of '{key}' is not a list") + + deserialized_tools = [] + for tool in serialized_tools: + if not isinstance(tool, dict): + raise TypeError(f"Serialized tool '{tool}' is not a dictionary") + deserialized_tools.append(Tool.from_dict(tool)) + + data[key] = deserialized_tools diff --git a/pyproject.toml b/pyproject.toml index c41c429ced..c1fddc8704 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "tenacity!=8.4.0", "lazy-imports", "openai>=1.56.1", + "pydantic", "Jinja2", "posthog", # telemetry "pyyaml", @@ -113,7 +114,7 @@ extra-dependencies = [ "jsonref", # OpenAPIServiceConnector, OpenAPIServiceToFunctions "openapi3", - # Validation + # JsonSchemaValidator, Tool "jsonschema", # Tracing diff --git a/releasenotes/notes/tool-dataclass-12756077bbfea3a1.yaml b/releasenotes/notes/tool-dataclass-12756077bbfea3a1.yaml new file mode 100644 index 0000000000..b6255ee1a9 --- /dev/null +++ b/releasenotes/notes/tool-dataclass-12756077bbfea3a1.yaml @@ -0,0 +1,8 @@ +--- +highlights: > + We are introducing the `Tool` dataclass: a simple and unified abstraction to represent tools throughout the framework. + By building on this abstraction, we will enable support for tools in Chat Generators, + providing a consistent experience across models. +features: + - | + Added a new `Tool` dataclass to represent a tool for which Language Models can prepare calls. diff --git a/test/dataclasses/test_tool.py b/test/dataclasses/test_tool.py new file mode 100644 index 0000000000..db9719a7f3 --- /dev/null +++ b/test/dataclasses/test_tool.py @@ -0,0 +1,305 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Literal, Optional + +import pytest + +from haystack.dataclasses.tool import ( + SchemaGenerationError, + Tool, + ToolInvocationError, + _remove_title_from_schema, + deserialize_tools_inplace, +) + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + + +def get_weather_report(city: str) -> str: + return f"Weather report for {city}: 20°C, sunny" + + +parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + + +def function_with_docstring(city: str) -> str: + """Get weather report for a city.""" + return f"Weather report for {city}: 20°C, sunny" + + +class TestTool: + def test_init(self): + tool = Tool( + name="weather", description="Get weather report", parameters=parameters, function=get_weather_report + ) + + assert tool.name == "weather" + assert tool.description == "Get weather report" + assert tool.parameters == parameters + assert tool.function == get_weather_report + + def test_init_invalid_parameters(self): + parameters = {"type": "invalid", "properties": {"city": {"type": "string"}}} + + with pytest.raises(ValueError): + Tool(name="irrelevant", description="irrelevant", parameters=parameters, function=get_weather_report) + + def test_tool_spec(self): + tool = Tool( + name="weather", description="Get weather report", parameters=parameters, function=get_weather_report + ) + + assert tool.tool_spec == {"name": "weather", "description": "Get weather report", "parameters": parameters} + + def test_invoke(self): + tool = Tool( + name="weather", description="Get weather report", parameters=parameters, function=get_weather_report + ) + + assert tool.invoke(city="Berlin") == "Weather report for Berlin: 20°C, sunny" + + def test_invoke_fail(self): + tool = Tool( + name="weather", description="Get weather report", parameters=parameters, function=get_weather_report + ) + + with pytest.raises(ToolInvocationError): + tool.invoke() + + def test_to_dict(self): + tool = Tool( + name="weather", description="Get weather report", parameters=parameters, function=get_weather_report + ) + + assert tool.to_dict() == { + "name": "weather", + "description": "Get weather report", + "parameters": parameters, + "function": "test_tool.get_weather_report", + } + + def test_from_dict(self): + tool_dict = { + "name": "weather", + "description": "Get weather report", + "parameters": parameters, + "function": "test_tool.get_weather_report", + } + + tool = Tool.from_dict(tool_dict) + + assert tool.name == "weather" + assert tool.description == "Get weather report" + assert tool.parameters == parameters + assert tool.function == get_weather_report + + def test_from_function_description_from_docstring(self): + tool = Tool.from_function(function=function_with_docstring) + + assert tool.name == "function_with_docstring" + assert tool.description == "Get weather report for a city." + assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + assert tool.function == function_with_docstring + + def test_from_function_with_empty_description(self): + tool = Tool.from_function(function=function_with_docstring, description="") + + assert tool.name == "function_with_docstring" + assert tool.description == "" + assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + assert tool.function == function_with_docstring + + def test_from_function_with_custom_description(self): + tool = Tool.from_function(function=function_with_docstring, description="custom description") + + assert tool.name == "function_with_docstring" + assert tool.description == "custom description" + assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + assert tool.function == function_with_docstring + + def test_from_function_with_custom_name(self): + tool = Tool.from_function(function=function_with_docstring, name="custom_name") + + assert tool.name == "custom_name" + assert tool.description == "Get weather report for a city." + assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + assert tool.function == function_with_docstring + + def test_from_function_missing_type_hint(self): + def function_missing_type_hint(city) -> str: + return f"Weather report for {city}: 20°C, sunny" + + with pytest.raises(ValueError): + Tool.from_function(function=function_missing_type_hint) + + def test_from_function_schema_generation_error(self): + def function_with_invalid_type_hint(city: "invalid") -> str: + return f"Weather report for {city}: 20°C, sunny" + + with pytest.raises(SchemaGenerationError): + Tool.from_function(function=function_with_invalid_type_hint) + + def test_from_function_annotated(self): + def function_with_annotations( + city: Annotated[str, "the city for which to get the weather"] = "Munich", + unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius", + nullable_param: Annotated[Optional[str], "a nullable parameter"] = None, + ) -> str: + """A simple function to get the current weather for a location.""" + return f"Weather report for {city}: 20 {unit}, sunny" + + tool = Tool.from_function(function=function_with_annotations) + + assert tool.name == "function_with_annotations" + assert tool.description == "A simple function to get the current weather for a location." + assert tool.parameters == { + "type": "object", + "properties": { + "city": {"type": "string", "description": "the city for which to get the weather", "default": "Munich"}, + "unit": { + "type": "string", + "enum": ["Celsius", "Fahrenheit"], + "description": "the unit for the temperature", + "default": "Celsius", + }, + "nullable_param": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "description": "a nullable parameter", + "default": None, + }, + }, + } + + +def test_deserialize_tools_inplace(): + tool = Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report) + serialized_tool = tool.to_dict() + print(serialized_tool) + + data = {"tools": [serialized_tool.copy()]} + deserialize_tools_inplace(data) + assert data["tools"] == [tool] + + data = {"mytools": [serialized_tool.copy()]} + deserialize_tools_inplace(data, key="mytools") + assert data["mytools"] == [tool] + + data = {"no_tools": 123} + deserialize_tools_inplace(data) + assert data == {"no_tools": 123} + + +def test_deserialize_tools_inplace_failures(): + data = {"key": "value"} + deserialize_tools_inplace(data) + assert data == {"key": "value"} + + data = {"tools": None} + deserialize_tools_inplace(data) + assert data == {"tools": None} + + data = {"tools": "not a list"} + with pytest.raises(TypeError): + deserialize_tools_inplace(data) + + data = {"tools": ["not a dictionary"]} + with pytest.raises(TypeError): + deserialize_tools_inplace(data) + + +def test_remove_title_from_schema(): + complex_schema = { + "properties": { + "parameter1": { + "anyOf": [{"type": "string"}, {"type": "integer"}], + "default": "default_value", + "title": "Parameter1", + }, + "parameter2": { + "default": [1, 2, 3], + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "title": "Parameter2", + "type": "array", + }, + "parameter3": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"}, + ], + "default": 42, + "title": "Parameter3", + }, + "parameter4": { + "anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}], + "default": {"key": "value"}, + "title": "Parameter4", + }, + }, + "title": "complex_function", + "type": "object", + } + + _remove_title_from_schema(complex_schema) + + assert complex_schema == { + "properties": { + "parameter1": {"anyOf": [{"type": "string"}, {"type": "integer"}], "default": "default_value"}, + "parameter2": { + "default": [1, 2, 3], + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + }, + "parameter3": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"}, + ], + "default": 42, + }, + "parameter4": { + "anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}], + "default": {"key": "value"}, + }, + }, + "type": "object", + } + + +def test_remove_title_from_schema_do_not_remove_title_property(): + """Test that the utility function only removes the 'title' keywords and not the 'title' property (if present).""" + schema = { + "properties": { + "parameter1": {"type": "string", "title": "Parameter1"}, + "title": {"type": "string", "title": "Title"}, + }, + "title": "complex_function", + "type": "object", + } + + _remove_title_from_schema(schema) + + assert schema == {"properties": {"parameter1": {"type": "string"}, "title": {"type": "string"}}, "type": "object"} + + +def test_remove_title_from_schema_handle_no_title_in_top_level(): + schema = { + "properties": { + "parameter1": {"type": "string", "title": "Parameter1"}, + "parameter2": {"type": "integer", "title": "Parameter2"}, + }, + "type": "object", + } + + _remove_title_from_schema(schema) + + assert schema == { + "properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}}, + "type": "object", + }