Skip to content

Commit

Permalink
Merge branch 'main' into add-recursive-chunking
Browse files Browse the repository at this point in the history
  • Loading branch information
davidsbatista authored Jan 3, 2025
2 parents e398120 + 8e3f647 commit 6977b2a
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 8 deletions.
4 changes: 0 additions & 4 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,6 @@ def _prepare_api_call( # noqa: PLR0913
}

def _handle_stream_response(self, chat_completion: Stream, callback: StreamingCallbackT) -> List[ChatMessage]:
print("callback")
print(callback)
print("-" * 100)

chunks: List[StreamingChunk] = []
chunk = None

Expand Down
78 changes: 78 additions & 0 deletions haystack/dataclasses/chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,81 @@ def to_openai_dict_format(self) -> Dict[str, Any]:
)
openai_msg["tool_calls"] = openai_tool_calls
return openai_msg

@staticmethod
def _validate_openai_message(message: Dict[str, Any]) -> None:
"""
Validate that a message dictionary follows OpenAI's Chat API format.
:param message: The message dictionary to validate
:raises ValueError: If the message format is invalid
"""
if "role" not in message:
raise ValueError("The `role` field is required in the message dictionary.")

role = message["role"]
content = message.get("content")
tool_calls = message.get("tool_calls")

if role not in ["assistant", "user", "system", "developer", "tool"]:
raise ValueError(f"Unsupported role: {role}")

if role == "assistant":
if not content and not tool_calls:
raise ValueError("For assistant messages, either `content` or `tool_calls` must be present.")
if tool_calls:
for tc in tool_calls:
if "function" not in tc:
raise ValueError("Tool calls must contain the `function` field")
elif not content:
raise ValueError(f"The `content` field is required for {role} messages.")

@classmethod
def from_openai_dict_format(cls, message: Dict[str, Any]) -> "ChatMessage":
    """
    Create a ChatMessage from a dictionary in the format expected by OpenAI's Chat API.

    NOTE: While OpenAI's API requires `tool_call_id` in both tool calls and tool messages, this method
    accepts messages without it to support shallow OpenAI-compatible APIs.
    If you plan to use the resulting ChatMessage with OpenAI, you must include `tool_call_id` or you'll
    encounter validation errors.

    :param message:
        The OpenAI dictionary to build the ChatMessage object.
    :returns:
        The created ChatMessage object.
    :raises ValueError:
        If the message dictionary is missing required fields.
    """
    cls._validate_openai_message(message)

    role = message["role"]
    content = message.get("content")
    name = message.get("name")

    if role == "assistant":
        raw_tool_calls = message.get("tool_calls")
        converted_tool_calls = None
        if raw_tool_calls:
            # OpenAI serializes tool-call arguments as a JSON string; decode into a dict.
            converted_tool_calls = [
                ToolCall(
                    id=tc.get("id"),
                    tool_name=tc["function"]["name"],
                    arguments=json.loads(tc["function"]["arguments"]),
                )
                for tc in raw_tool_calls
            ]
        return cls.from_assistant(text=content, name=name, tool_calls=converted_tool_calls)

    assert content is not None  # ensured by _validate_openai_message, but we need to make mypy happy

    if role == "user":
        return cls.from_user(text=content, name=name)
    if role in ["system", "developer"]:
        return cls.from_system(text=content, name=name)

    # Remaining role is "tool"; a missing tool_call_id is tolerated (shallow APIs).
    return cls.from_tool(
        tool_result=content,
        origin=ToolCall(id=message.get("tool_call_id"), tool_name="", arguments={}),
        error=False,
    )
9 changes: 5 additions & 4 deletions haystack/utils/callable_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
# SPDX-License-Identifier: Apache-2.0

import inspect
import sys
from typing import Callable, Optional

from haystack import DeserializationError
from haystack.utils.type_serialization import thread_safe_import


def serialize_callable(callable_handle: Callable) -> str:
Expand Down Expand Up @@ -37,9 +37,10 @@ def deserialize_callable(callable_handle: str) -> Optional[Callable]:
parts = callable_handle.split(".")
module_name = ".".join(parts[:-1])
function_name = parts[-1]
module = sys.modules.get(module_name, None)
if not module:
raise DeserializationError(f"Could not locate the module of the callable: {module_name}")
try:
module = thread_safe_import(module_name)
except Exception as e:
raise DeserializationError(f"Could not locate the module of the callable: {module_name}") from e
deserialized_callable = getattr(module, function_name, None)
if not deserialized_callable:
raise DeserializationError(f"Could not locate the callable: {function_name}")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Add the `from_openai_dict_format` class method to the `ChatMessage` class. It allows you to create a `ChatMessage`
from a dictionary in the format expected by OpenAI's Chat API.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
enhancements:
- |
Improved deserialization of callables by using `importlib` instead of `sys.modules`.
This change allows importing local functions and classes that are not in `sys.modules`
when deserializing callables.
80 changes: 80 additions & 0 deletions test/dataclasses/test_chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,86 @@ def test_to_openai_dict_format_invalid():
message.to_openai_dict_format()


def test_from_openai_dict_format_user_message():
    message = ChatMessage.from_openai_dict_format({"role": "user", "content": "Hello, how are you?", "name": "John"})
    assert (message.role.value, message.text, message.name) == ("user", "Hello, how are you?", "John")


def test_from_openai_dict_format_system_message():
    message = ChatMessage.from_openai_dict_format({"role": "system", "content": "You are a helpful assistant"})
    assert message.role.value == "system"
    assert message.text == "You are a helpful assistant"


def test_from_openai_dict_format_assistant_message_with_content():
    message = ChatMessage.from_openai_dict_format({"role": "assistant", "content": "I can help with that"})
    assert (message.role.value, message.text) == ("assistant", "I can help with that")


def test_from_openai_dict_format_assistant_message_with_tool_calls():
    tool_call_payload = {"id": "call_123", "function": {"name": "get_weather", "arguments": '{"location": "Berlin"}'}}
    message = ChatMessage.from_openai_dict_format(
        {"role": "assistant", "content": None, "tool_calls": [tool_call_payload]}
    )

    assert message.role.value == "assistant"
    assert message.text is None

    assert len(message.tool_calls) == 1
    converted = message.tool_calls[0]
    assert (converted.id, converted.tool_name) == ("call_123", "get_weather")
    assert converted.arguments == {"location": "Berlin"}


def test_from_openai_dict_format_tool_message():
    message = ChatMessage.from_openai_dict_format(
        {"role": "tool", "content": "The weather is sunny", "tool_call_id": "call_123"}
    )
    assert message.role.value == "tool"
    result = message.tool_call_result
    assert result.result == "The weather is sunny"
    assert result.origin.id == "call_123"


def test_from_openai_dict_format_tool_without_id():
    # tool_call_id omitted on purpose: shallow OpenAI-compatible APIs may not provide one.
    message = ChatMessage.from_openai_dict_format({"role": "tool", "content": "The weather is sunny"})
    assert message.role.value == "tool"
    result = message.tool_call_result
    assert result.result == "The weather is sunny"
    assert result.origin.id is None


def test_from_openai_dict_format_missing_role():
    message_without_role = {"content": "test"}
    with pytest.raises(ValueError):
        ChatMessage.from_openai_dict_format(message_without_role)


def test_from_openai_dict_format_missing_content():
    message_without_content = {"role": "user"}
    with pytest.raises(ValueError):
        ChatMessage.from_openai_dict_format(message_without_content)


def test_from_openai_dict_format_invalid_tool_calls():
    # A tool call lacking the required `function` field must be rejected.
    with pytest.raises(ValueError):
        ChatMessage.from_openai_dict_format({"role": "assistant", "tool_calls": [{"invalid": "format"}]})


def test_from_openai_dict_format_unsupported_role():
    unsupported = {"role": "invalid", "content": "test"}
    with pytest.raises(ValueError):
        ChatMessage.from_openai_dict_format(unsupported)


def test_from_openai_dict_format_assistant_missing_content_and_tool_calls():
    # Assistant messages need at least one of `content` / `tool_calls`.
    with pytest.raises(ValueError):
        ChatMessage.from_openai_dict_format({"role": "assistant", "irrelevant": "irrelevant"})


@pytest.mark.integration
def test_apply_chat_templating_on_chat_message():
messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
Expand Down
9 changes: 9 additions & 0 deletions test/utils/test_callable_serialization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import requests

from haystack import DeserializationError
from haystack.components.generators.utils import print_streaming_chunk
from haystack.utils import serialize_callable, deserialize_callable

Expand Down Expand Up @@ -36,3 +38,10 @@ def test_callable_deserialization_non_local():
result = serialize_callable(requests.api.get)
fn = deserialize_callable(result)
assert fn is requests.api.get


def test_callable_deserialization_error():
    # Both an unimportable module and a missing attribute on a real module must fail.
    for invalid_handle in ("this.is.not.a.valid.module", "sys.foobar"):
        with pytest.raises(DeserializationError):
            deserialize_callable(invalid_handle)

0 comments on commit 6977b2a

Please sign in to comment.