Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding the OpenAIFunctionCaller #14

Merged
merged 18 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions haystack_experimental/components/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
3 changes: 3 additions & 0 deletions haystack_experimental/components/tools/openai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
105 changes: 105 additions & 0 deletions haystack_experimental/components/tools/openai/function_caller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import json
from typing import Any, List, Dict, Callable

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage
from haystack.utils import serialize_callable, deserialize_callable

_FUNCTION_NAME_FAILURE = "I'm sorry, I tried to run a function that did not exist. Would you like me to correct it and try again?"
_FUNCTION_RUN_FAILURE = "Seems there was an error while runnign the function: {error}"
TuanaCelik marked this conversation as resolved.
Show resolved Hide resolved


@component
class OpenAIFunctionCaller:
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
"""
The OpenAIFunctionCaller expects a list of ChatMessages and if there is a tool call with a function name and arguments, it runs the function and returns the
result as a ChatMessage from role = 'function'
"""

def __init__(self, available_functions: Dict[str, Callable]):
"""
Initialize the OpenAIFunctionCaller component.
:param available_functions: A dictionary of available functions. This dictionary expects key value pairs of function name, and the function itself. For example, `{"weather_function": weather_function}`
"""
self.available_functions = available_functions
TuanaCelik marked this conversation as resolved.
Show resolved Hide resolved

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.

:returns:
Dictionary with serialized data.
"""
available_function_callbacks = {}
for function in self.available_functions:
available_function_callbacks[function] = (
serialize_callable(self.available_functions[function])
if function
else None
)
serialization_dict = default_to_dict(
self, available_functions=available_function_callbacks
)
TuanaCelik marked this conversation as resolved.
Show resolved Hide resolved
return serialization_dict

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OpenAIFunctionCaller":
"""
Deserializes the component from a dictionary.

:param data:
The dictionary to deserialize from.
:returns:
The deserialized component.
"""
init_params = data.get("init_parameters", {})
available_function_callback_handler = init_params.get("available_functions")
if available_function_callback_handler:
for callback in data["init_parameters"]["available_functions"]:
print(callback)
deserialize_callable(
data["init_parameters"]["available_functions"][callback]
)
TuanaCelik marked this conversation as resolved.
Show resolved Hide resolved

return default_from_dict(cls, data)
TuanaCelik marked this conversation as resolved.
Show resolved Hide resolved

@component.output_types(
function_replies=List[ChatMessage], assistant_replies=List[ChatMessage]
)
def run(self, messages: List[ChatMessage]):
"""
Evaluates `messages` and invokes available functions if the messages contain tool_calls.
:param messages: A list of messages generated from the `OpenAIChatGenerator`
:returns: This component returns a list of messages in one of two outputs
- `function_replies`: List of ChatMessages containing the result of a function invocation. This message comes from role = 'function'. If the function name was hallucinated or wrong, an assistant message explaining as such is returned
- `assistant_replies`: List of ChatMessages containing a regular assistant reply. In this case, there were no tool_calls in the received messages
"""
if messages[0].meta["finish_reason"] == "tool_calls":
function_calls = json.loads(messages[0].content)
for function_call in function_calls:
function_name = function_call["function"]["name"]
function_args = json.loads(function_call["function"]["arguments"])
if function_name in self.available_functions:
function_to_call = self.available_functions[function_name]
try:
function_response = function_to_call(**function_args)
messages.append(
ChatMessage.from_function(
content=json.dumps(function_response),
name=function_name,
)
)
except BaseException as e:
messages.append(
ChatMessage.from_assistant(
_FUNCTION_RUN_FAILURE.format(error=e)
)
)
else:
messages.append(ChatMessage.from_assistant(_FUNCTION_NAME_FAILURE))
return {"function_replies": messages}
return {"assistant_replies": messages}
Loading