From 4b941b06544651b0cc4ea788ee7b66fb21c0f519 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 13 Dec 2024 08:46:55 +0100 Subject: [PATCH 01/12] draft --- .../builders/chat_prompt_builder.py | 6 +- haystack/dataclasses/__init__.py | 5 +- haystack/dataclasses/chat_message.py | 265 ++++++++++++++---- .../builders/test_chat_prompt_builder.py | 12 +- .../routers/test_conditional_router.py | 6 +- test/dataclasses/test_chat_message.py | 232 ++++++++++----- 6 files changed, 394 insertions(+), 132 deletions(-) diff --git a/haystack/components/builders/chat_prompt_builder.py b/haystack/components/builders/chat_prompt_builder.py index fd9969f5b7..33e2feda2d 100644 --- a/haystack/components/builders/chat_prompt_builder.py +++ b/haystack/components/builders/chat_prompt_builder.py @@ -9,7 +9,7 @@ from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses.chat_message import ChatMessage, ChatRole +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent logger = logging.getLogger(__name__) @@ -197,10 +197,10 @@ def run( if message.text is None: raise ValueError(f"The provided ChatMessage has no text. 
ChatMessage: {message}") compiled_template = self._env.from_string(message.text) - rendered_content = compiled_template.render(template_variables_combined) + rendered_text = compiled_template.render(template_variables_combined) # deep copy the message to avoid modifying the original message rendered_message: ChatMessage = deepcopy(message) - rendered_message.content = rendered_content + rendered_message._content = [TextContent(text=rendered_text)] processed_messages.append(rendered_message) else: processed_messages.append(message) diff --git a/haystack/dataclasses/__init__.py b/haystack/dataclasses/__init__.py index 231ce80713..bcdd6acdd7 100644 --- a/haystack/dataclasses/__init__.py +++ b/haystack/dataclasses/__init__.py @@ -4,7 +4,7 @@ from haystack.dataclasses.answer import Answer, ExtractedAnswer, GeneratedAnswer from haystack.dataclasses.byte_stream import ByteStream -from haystack.dataclasses.chat_message import ChatMessage, ChatRole +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall, ToolCallResult, TextContent from haystack.dataclasses.document import Document from haystack.dataclasses.sparse_embedding import SparseEmbedding from haystack.dataclasses.streaming_chunk import StreamingChunk @@ -17,6 +17,9 @@ "ByteStream", "ChatMessage", "ChatRole", + "ToolCall", + "ToolCallResult", + "TextContent", "StreamingChunk", "SparseEmbedding", ] diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index fb15ee6f5e..e0a437513f 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -5,16 +5,81 @@ import warnings from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union, List, Sequence class ChatRole(str, Enum): - """Enumeration representing the roles within a chat.""" + """ + Enumeration representing the roles within a chat. 
+ """ - ASSISTANT = "assistant" + #: The user role. A message from the user contains only text. USER = "user" + + #: The system role. A message from the system contains only text. SYSTEM = "system" - FUNCTION = "function" + + #: The assistant role. A message from the assistant can contain text and Tool calls. It can also store metadata. + ASSISTANT = "assistant" + + #: The tool role. A message from a tool contains the result of a Tool invocation. + TOOL = "tool" + + @staticmethod + def from_str(string: str) -> "ChatRole": + """ + Convert a string to a ChatRole enum. + """ + enum_map = {e.value: e for e in ChatRole} + role = enum_map.get(string) + if role is None: + msg = f"Unknown chat role '{string}'. Supported roles are: {list(enum_map.keys())}" + raise ValueError(msg) + return role + + +@dataclass +class ToolCall: + """ + Represents a Tool call prepared by the model, usually contained in an assistant message. + + :param id: The ID of the Tool call. + :param tool_name: The name of the Tool to call. + :param arguments: The arguments to call the Tool with. + """ + + tool_name: str + arguments: Dict[str, Any] + id: Optional[str] = None # noqa: A003 + + +@dataclass +class ToolCallResult: + """ + Represents the result of a Tool invocation. + + :param result: The result of the Tool invocation. + :param origin: The Tool call that produced this result. + :param error: Whether the Tool invocation resulted in an error. + """ + + result: str + origin: ToolCall + error: bool + + +@dataclass +class TextContent: + """ + The textual content of a chat message. + + :param text: The text content of the message. + """ + + text: str + + +ChatMessageContentT = Union[TextContent, ToolCall, ToolCallResult] @dataclass @@ -22,26 +87,88 @@ class ChatMessage: """ Represents a message in a LLM chat conversation. - :param content: The text content of the message. - :param role: The role of the entity sending the message. 
- :param name: The name of the function being called (only applicable for role FUNCTION). - :param meta: Additional metadata associated with the message. + Use the `from_assistant`, `from_user`, `from_system`, and `from_tool` class methods to create a ChatMessage. """ - content: str - role: ChatRole - name: Optional[str] - meta: Dict[str, Any] = field(default_factory=dict, hash=False) + _role: ChatRole + _content: Sequence[ChatMessageContentT] + _meta: Dict[str, Any] = field(default_factory=dict, hash=False) + + def __len__(self): + return len(self._content) + + @property + def role(self) -> ChatRole: + """ + Returns the role of the entity sending the message. + """ + return self._role + + @property + def meta(self) -> Dict[str, Any]: + """ + Returns the metadata associated with the message. + """ + return self._meta + + @property + def texts(self) -> List[str]: + """ + Returns the list of all texts contained in the message. + """ + return [content.text for content in self._content if isinstance(content, TextContent)] @property def text(self) -> Optional[str]: """ - Returns the textual content of the message. + Returns the first text contained in the message. + """ + if texts := self.texts: + return texts[0] + return None + + @property + def tool_calls(self) -> List[ToolCall]: + """ + Returns the list of all Tool calls contained in the message. + """ + return [content for content in self._content if isinstance(content, ToolCall)] + + @property + def tool_call(self) -> Optional[ToolCall]: + """ + Returns the first Tool call contained in the message. """ - # Currently, this property mirrors the `content` attribute. This will change in 2.9.0. - # The current actual return type is str. We are using Optional[str] to be ready for 2.9.0, - # when None will be a valid value for `text`. 
- return object.__getattribute__(self, "content") + if tool_calls := self.tool_calls: + return tool_calls[0] + return None + + @property + def tool_call_results(self) -> List[ToolCallResult]: + """ + Returns the list of all Tool call results contained in the message. + """ + return [content for content in self._content if isinstance(content, ToolCallResult)] + + @property + def tool_call_result(self) -> Optional[ToolCallResult]: + """ + Returns the first Tool call result contained in the message. + """ + if tool_call_results := self.tool_call_results: + return tool_call_results[0] + return None + + def is_from(self, role: Union[ChatRole, str]) -> bool: + """ + Check if the message is from a specific role. + + :param role: The role to check against. + :returns: True if the message is from the specified role, False otherwise. + """ + if isinstance(role, str): + role = ChatRole.from_str(role) + return self._role == role def __getattribute__(self, name): # this method is reimplemented to warn about the deprecation of the `content` attribute @@ -53,56 +180,69 @@ def __getattribute__(self, name): warnings.warn(msg, DeprecationWarning) return object.__getattribute__(self, name) - def is_from(self, role: ChatRole) -> bool: - """ - Check if the message is from a specific role. - - :param role: The role to check against. - :returns: True if the message is from the specified role, False otherwise. - """ - return self.role == role - @classmethod - def from_assistant(cls, content: str, meta: Optional[Dict[str, Any]] = None) -> "ChatMessage": + def from_user(cls, text: str, meta: Optional[Dict[str, Any]] = None) -> "ChatMessage": """ - Create a message from the assistant. + Create a message from the user. - :param content: The text content of the message. + :param text: The text content of the message. :param meta: Additional metadata associated with the message. :returns: A new ChatMessage instance. 
""" - return cls(content, ChatRole.ASSISTANT, None, meta or {}) + return cls(_role=ChatRole.USER, _content=[TextContent(text=text)], _meta=meta or {}) @classmethod - def from_user(cls, content: str) -> "ChatMessage": + def from_system(cls, text: str, meta: Optional[Dict[str, Any]] = None) -> "ChatMessage": """ - Create a message from the user. + Create a message from the system. - :param content: The text content of the message. + :param text: The text content of the message. + :param meta: Additional metadata associated with the message. :returns: A new ChatMessage instance. """ - return cls(content, ChatRole.USER, None) + return cls(_role=ChatRole.SYSTEM, _content=[TextContent(text=text)], _meta=meta or {}) @classmethod - def from_system(cls, content: str) -> "ChatMessage": + def from_assistant( + cls, + text: Optional[str] = None, + meta: Optional[Dict[str, Any]] = None, + tool_calls: Optional[List[ToolCall]] = None, + ) -> "ChatMessage": """ - Create a message from the system. + Create a message from the assistant. - :param content: The text content of the message. + :param text: The text content of the message. + :param meta: Additional metadata associated with the message. + :param tool_calls: The Tool calls to include in the message. :returns: A new ChatMessage instance. """ - return cls(content, ChatRole.SYSTEM, None) + content: List[ChatMessageContentT] = [] + if text is not None: + content.append(TextContent(text=text)) + if tool_calls: + content.extend(tool_calls) + + return cls(_role=ChatRole.ASSISTANT, _content=content, _meta=meta or {}) @classmethod - def from_function(cls, content: str, name: str) -> "ChatMessage": + def from_tool( + cls, tool_result: str, origin: ToolCall, error: bool = False, meta: Optional[Dict[str, Any]] = None + ) -> "ChatMessage": """ - Create a message from a function call. + Create a message from a Tool. - :param content: The text content of the message. - :param name: The name of the function being called. 
+ :param tool_result: The result of the Tool invocation. + :param origin: The Tool call that produced this result. + :param error: Whether the Tool invocation resulted in an error. + :param meta: Additional metadata associated with the message. :returns: A new ChatMessage instance. """ - return cls(content, ChatRole.FUNCTION, name) + return cls( + _role=ChatRole.TOOL, + _content=[ToolCallResult(result=tool_result, origin=origin, error=error)], + _meta=meta or {}, + ) def to_dict(self) -> Dict[str, Any]: """ @@ -111,10 +251,23 @@ def to_dict(self) -> Dict[str, Any]: :returns: Serialized version of the object. """ - data = asdict(self) - data["role"] = self.role.value + serialized: Dict[str, Any] = {} + serialized["_role"] = self._role.value + serialized["_meta"] = self._meta + + content: List[Dict[str, Any]] = [] + for part in self._content: + if isinstance(part, TextContent): + content.append({"text": part.text}) + elif isinstance(part, ToolCall): + content.append({"tool_call": asdict(part)}) + elif isinstance(part, ToolCallResult): + content.append({"tool_call_result": asdict(part)}) + else: + raise TypeError(f"Unsupported type in ChatMessage content: `{type(part).__name__}` for `{part}`.") - return data + serialized["_content"] = content + return serialized @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": @@ -126,6 +279,24 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": :returns: The created object. 
""" - data["role"] = ChatRole(data["role"]) + data["_role"] = ChatRole(data["_role"]) + + content: List[ChatMessageContentT] = [] + + for part in data["_content"]: + if "text" in part: + content.append(TextContent(text=part["text"])) + elif "tool_call" in part: + content.append(ToolCall(**part["tool_call"])) + elif "tool_call_result" in part: + result = part["tool_call_result"]["result"] + origin = ToolCall(**part["tool_call_result"]["origin"]) + error = part["tool_call_result"]["error"] + tcr = ToolCallResult(result=result, origin=origin, error=error) + content.append(tcr) + else: + raise ValueError(f"Unsupported content in serialized ChatMessage: `{part}`") + + data["_content"] = content return cls(**data) diff --git a/test/components/builders/test_chat_prompt_builder.py b/test/components/builders/test_chat_prompt_builder.py index 5e1ae6132e..f981afa1b3 100644 --- a/test/components/builders/test_chat_prompt_builder.py +++ b/test/components/builders/test_chat_prompt_builder.py @@ -13,8 +13,8 @@ class TestChatPromptBuilder: def test_init(self): builder = ChatPromptBuilder( template=[ - ChatMessage.from_user(content="This is a {{ variable }}"), - ChatMessage.from_system(content="This is a {{ variable2 }}"), + ChatMessage.from_user("This is a {{ variable }}"), + ChatMessage.from_system("This is a {{ variable2 }}"), ] ) assert builder.required_variables == [] @@ -531,8 +531,8 @@ def test_to_dict(self): "type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder", "init_parameters": { "template": [ - {"content": "text and {var}", "role": "user", "name": None, "meta": {}}, - {"content": "content {required_var}", "role": "assistant", "name": None, "meta": {}}, + {"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}}, + {"_content": [{"text": "content {required_var}"}], "_role": "assistant", "_meta": {}}, ], "variables": ["var", "required_var"], "required_variables": ["required_var"], @@ -545,8 +545,8 @@ def test_from_dict(self): "type": 
"haystack.components.builders.chat_prompt_builder.ChatPromptBuilder", "init_parameters": { "template": [ - {"content": "text and {var}", "role": "user", "name": None, "meta": {}}, - {"content": "content {required_var}", "role": "assistant", "name": None, "meta": {}}, + {"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}}, + {"_content": [{"text": "content {required_var}"}], "_role": "assistant", "_meta": {}}, ], "variables": ["var", "required_var"], "required_variables": ["required_var"], diff --git a/test/components/routers/test_conditional_router.py b/test/components/routers/test_conditional_router.py index e0f3552319..66d941b645 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -349,7 +349,7 @@ def test_unsafe(self): ] router = ConditionalRouter(routes, unsafe=True) streams = [1] - message = ChatMessage.from_user(content="This is a message") + message = ChatMessage.from_user("This is a message") res = router.run(streams=streams, message=message) assert res == {"message": message} @@ -370,7 +370,7 @@ def test_validate_output_type_without_unsafe(self): ] router = ConditionalRouter(routes, validate_output_type=True) streams = [1] - message = ChatMessage.from_user(content="This is a message") + message = ChatMessage.from_user("This is a message") with pytest.raises(ValueError, match="Route 'message' type doesn't match expected type"): router.run(streams=streams, message=message) @@ -391,7 +391,7 @@ def test_validate_output_type_with_unsafe(self): ] router = ConditionalRouter(routes, unsafe=True, validate_output_type=True) streams = [1] - message = ChatMessage.from_user(content="This is a message") + message = ChatMessage.from_user("This is a message") res = router.run(streams=streams, message=message) assert isinstance(res["message"], ChatMessage) diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index 30ad51630e..c3357882b4 100644 
--- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -4,64 +4,189 @@ import pytest from transformers import AutoTokenizer -from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall, ToolCallResult, TextContent from haystack.components.generators.openai_utils import _convert_message_to_openai_format +def test_tool_call_init(): + tc = ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) + assert tc.id == "123" + assert tc.tool_name == "mytool" + assert tc.arguments == {"a": 1} + + +def test_tool_call_result_init(): + tcr = ToolCallResult(result="result", origin=ToolCall(id="123", tool_name="mytool", arguments={"a": 1}), error=True) + assert tcr.result == "result" + assert tcr.origin == ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) + assert tcr.error + + +def test_text_content_init(): + tc = TextContent(text="Hello") + assert tc.text == "Hello" + + def test_from_assistant_with_valid_content(): - content = "Hello, how can I assist you?" - message = ChatMessage.from_assistant(content) - assert message.content == content - assert message.text == content + text = "Hello, how can I assist you?" 
+ message = ChatMessage.from_assistant(text) + + assert message._role == ChatRole.ASSISTANT + assert message._content == [TextContent(text)] + + assert message.text == text + assert message.texts == [text] + + assert not message.tool_calls + assert not message.tool_call + assert not message.tool_call_results + assert not message.tool_call_result + + +def test_from_assistant_with_tool_calls(): + tool_calls = [ + ToolCall(id="123", tool_name="mytool", arguments={"a": 1}), + ToolCall(id="456", tool_name="mytool2", arguments={"b": 2}), + ] + + message = ChatMessage.from_assistant(tool_calls=tool_calls) + assert message.role == ChatRole.ASSISTANT + assert message._content == tool_calls + + assert message.tool_calls == tool_calls + assert message.tool_call == tool_calls[0] + + assert not message.texts + assert not message.text + assert not message.tool_call_results + assert not message.tool_call_result def test_from_user_with_valid_content(): - content = "I have a question." - message = ChatMessage.from_user(content) - assert message.content == content - assert message.text == content + text = "I have a question." + message = ChatMessage.from_user(text=text) + assert message.role == ChatRole.USER + assert message._content == [TextContent(text)] + + assert message.text == text + assert message.texts == [text] + + assert not message.tool_calls + assert not message.tool_call + assert not message.tool_call_results + assert not message.tool_call_result def test_from_system_with_valid_content(): - content = "System message." - message = ChatMessage.from_system(content) - assert message.content == content - assert message.text == content + text = "I have a question." 
+ message = ChatMessage.from_system(text=text) + assert message.role == ChatRole.SYSTEM + assert message._content == [TextContent(text)] + assert message.text == text + assert message.texts == [text] -def test_with_empty_content(): - message = ChatMessage.from_user("") - assert message.content == "" - assert message.text == "" - assert message.role == ChatRole.USER + assert not message.tool_calls + assert not message.tool_call + assert not message.tool_call_results + assert not message.tool_call_result + + +def test_from_tool_with_valid_content(): + tool_result = "Tool result" + origin = ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) + message = ChatMessage.from_tool(tool_result, origin, error=False) + + tcr = ToolCallResult(result=tool_result, origin=origin, error=False) + + assert message._content == [tcr] + assert message.role == ChatRole.TOOL + + assert message.tool_call_result == tcr + assert message.tool_call_results == [tcr] + + assert not message.tool_calls + assert not message.tool_call + assert not message.texts + assert not message.text + + +def test_multiple_text_segments(): + texts = [TextContent(text="Hello"), TextContent(text="World")] + message = ChatMessage(_role=ChatRole.USER, _content=texts) + assert message.texts == ["Hello", "World"] + assert len(message) == 2 -def test_from_function_with_empty_name(): - content = "Function call" - message = ChatMessage.from_function(content, "") - assert message.content == content - assert message.text == content - assert message.name == "" - assert message.role == ChatRole.FUNCTION +def test_mixed_content(): + content = [TextContent(text="Hello"), ToolCall(id="123", tool_name="mytool", arguments={"a": 1})] -def test_to_openai_format(): - message = ChatMessage.from_system("You are good assistant") - assert _convert_message_to_openai_format(message) == {"role": "system", "content": "You are good assistant"} + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=content) - message = 
ChatMessage.from_user("I have a question") - assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"} + assert len(message) == 2 + assert message.texts == ["Hello"] + assert message.text == "Hello" - message = ChatMessage.from_function("Function call", "function_name") - assert _convert_message_to_openai_format(message) == { - "role": "function", - "content": "Function call", - "name": "function_name", + assert message.tool_calls == [content[1]] + assert message.tool_call == content[1] + + +def test_serde(): + # the following message is created just for testing purposes and does not make sense in a real use case + + role = ChatRole.ASSISTANT + + text_content = TextContent(text="Hello") + tool_call = ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) + tool_call_result = ToolCallResult(result="result", origin=tool_call, error=False) + meta = {"some": "info"} + + message = ChatMessage(_role=role, _content=[text_content, tool_call, tool_call_result], _meta=meta) + + serialized_message = message.to_dict() + assert serialized_message == { + "_content": [ + {"text": "Hello"}, + {"tool_call": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}}}, + { + "tool_call_result": { + "result": "result", + "error": False, + "origin": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}}, + } + }, + ], + "_role": "assistant", + "_meta": {"some": "info"}, } + deserialized_message = ChatMessage.from_dict(serialized_message) + assert deserialized_message == message + + +def test_to_dict_with_invalid_content_type(): + text_content = TextContent(text="Hello") + invalid_content = "invalid" + + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[text_content, invalid_content]) + + with pytest.raises(TypeError): + message.to_dict() + + +def test_from_dict_with_invalid_content_type(): + data = {"_role": "assistant", "_content": [{"text": "Hello"}, "invalid"]} + with pytest.raises(ValueError): + 
ChatMessage.from_dict(data) + + data = {"_role": "assistant", "_content": [{"text": "Hello"}, {"invalid": "invalid"}]} + with pytest.raises(ValueError): + ChatMessage.from_dict(data) + @pytest.mark.integration def test_apply_chat_templating_on_chat_message(): @@ -93,40 +218,3 @@ def test_apply_custom_chat_templating_on_chat_message(): formatted_messages, chat_template=anthropic_template, tokenize=False ) assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:" - - -def test_to_dict(): - content = "content" - role = "user" - meta = {"some": "some"} - - message = ChatMessage.from_user(content) - message.meta.update(meta) - - assert message.text == content - assert message.to_dict() == {"content": content, "role": role, "name": None, "meta": meta} - - -def test_from_dict(): - assert ChatMessage.from_dict(data={"content": "text", "role": "user", "name": None}) == ChatMessage.from_user( - "text" - ) - - -def test_from_dict_with_meta(): - data = {"content": "text", "role": "assistant", "name": None, "meta": {"something": "something"}} - assert ChatMessage.from_dict(data) == ChatMessage.from_assistant("text", meta={"something": "something"}) - - -def test_content_deprecation_warning(recwarn): - message = ChatMessage.from_user("my message") - - # accessing the content attribute triggers the deprecation warning - _ = message.content - assert len(recwarn) == 1 - wrn = recwarn.pop(DeprecationWarning) - assert "`content` attribute" in wrn.message.args[0] - - # accessing the text property does not trigger a warning - assert message.text == "my message" - assert len(recwarn) == 0 From 9ec7de73849b124bb527091bd99b88d13c13ccd0 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 13 Dec 2024 08:56:04 +0100 Subject: [PATCH 02/12] del HF token in tests --- .../classifiers/test_zero_shot_document_classifier.py | 2 ++ test/components/generators/chat/test_hugging_face_local.py | 1 + .../components/generators/test_hugging_face_local_generator.py | 1 + 
.../components/rankers/test_sentence_transformers_diversity.py | 2 +- test/components/rankers/test_transformers_similarity.py | 1 + test/components/readers/test_extractive.py | 3 +++ test/components/routers/test_transformers_text_router.py | 3 +++ test/components/routers/test_zero_shot_text_router.py | 2 ++ 8 files changed, 14 insertions(+), 1 deletion(-) diff --git a/test/components/classifiers/test_zero_shot_document_classifier.py b/test/components/classifiers/test_zero_shot_document_classifier.py index 7d679e3d21..be4d04a9fe 100644 --- a/test/components/classifiers/test_zero_shot_document_classifier.py +++ b/test/components/classifiers/test_zero_shot_document_classifier.py @@ -45,6 +45,7 @@ def test_to_dict(self): def test_from_dict(self, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier", "init_parameters": { @@ -73,6 +74,7 @@ def test_from_dict(self, monkeypatch): def test_from_dict_no_default_parameters(self, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier", "init_parameters": {"model": "cross-encoder/nli-deberta-v3-xsmall", "labels": ["positive", "negative"]}, diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 433917ec23..8f6749c2d8 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -166,6 +166,7 @@ def test_from_dict(self, model_info_mock): @patch("haystack.components.generators.chat.hugging_face_local.pipeline") def test_warm_up(self, pipeline_mock, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + 
monkeypatch.delenv("HF_TOKEN", raising=False) generator = HuggingFaceLocalChatGenerator( model="mistralai/Mistral-7B-Instruct-v0.2", task="text2text-generation", diff --git a/test/components/generators/test_hugging_face_local_generator.py b/test/components/generators/test_hugging_face_local_generator.py index 5c3b162a31..bded2e8d47 100644 --- a/test/components/generators/test_hugging_face_local_generator.py +++ b/test/components/generators/test_hugging_face_local_generator.py @@ -18,6 +18,7 @@ class TestHuggingFaceLocalGenerator: @patch("haystack.utils.hf.model_info") def test_init_default(self, model_info_mock, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) model_info_mock.return_value.pipeline_tag = "text2text-generation" generator = HuggingFaceLocalGenerator() diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py index eabd2ac375..018b443987 100644 --- a/test/components/rankers/test_sentence_transformers_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -273,7 +273,7 @@ def test_warm_up(self, similarity, monkeypatch): Test that ranker loads the SentenceTransformer model correctly during warm up. 
""" monkeypatch.delenv("HF_API_TOKEN", raising=False) - + monkeypatch.delenv("HF_TOKEN", raising=False) mock_model_class = MagicMock() mock_model_instance = MagicMock() mock_model_class.return_value = mock_model_instance diff --git a/test/components/rankers/test_transformers_similarity.py b/test/components/rankers/test_transformers_similarity.py index 6031d85e15..616bfa6647 100644 --- a/test/components/rankers/test_transformers_similarity.py +++ b/test/components/rankers/test_transformers_similarity.py @@ -313,6 +313,7 @@ def test_device_map_and_device_raises(self, caplog): @patch("haystack.components.rankers.transformers_similarity.AutoModelForSequenceClassification.from_pretrained") def test_device_map_dict(self, mocked_automodel, _mocked_autotokenizer, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) ranker = TransformersSimilarityRanker("model", model_kwargs={"device_map": {"layer_1": 1, "classifier": "cpu"}}) class MockedModel: diff --git a/test/components/readers/test_extractive.py b/test/components/readers/test_extractive.py index aedfaa13bc..a2f658b79b 100644 --- a/test/components/readers/test_extractive.py +++ b/test/components/readers/test_extractive.py @@ -519,6 +519,7 @@ def __init__(self): @patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") def test_device_map_auto(mocked_automodel, _mocked_autotokenizer, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) reader = ExtractiveReader("deepset/roberta-base-squad2", model_kwargs={"device_map": "auto"}) auto_device = ComponentDevice.resolve_device(None) @@ -537,6 +538,7 @@ def __init__(self): @patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") def test_device_map_str(mocked_automodel, _mocked_autotokenizer, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", 
raising=False) reader = ExtractiveReader("deepset/roberta-base-squad2", model_kwargs={"device_map": "cpu:0"}) class MockedModel: @@ -554,6 +556,7 @@ def __init__(self): @patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") def test_device_map_dict(mocked_automodel, _mocked_autotokenizer, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) reader = ExtractiveReader( "deepset/roberta-base-squad2", model_kwargs={"device_map": {"layer_1": 1, "classifier": "cpu"}} ) diff --git a/test/components/routers/test_transformers_text_router.py b/test/components/routers/test_transformers_text_router.py index 8a0dca8d63..67ec163524 100644 --- a/test/components/routers/test_transformers_text_router.py +++ b/test/components/routers/test_transformers_text_router.py @@ -54,6 +54,7 @@ def test_to_dict_with_cpu_device(self, mock_auto_config_from_pretrained): def test_from_dict(self, mock_auto_config_from_pretrained, monkeypatch): mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.routers.transformers_text_router.TransformersTextRouter", "init_parameters": { @@ -84,6 +85,7 @@ def test_from_dict(self, mock_auto_config_from_pretrained, monkeypatch): def test_from_dict_no_default_parameters(self, mock_auto_config_from_pretrained, monkeypatch): mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.routers.transformers_text_router.TransformersTextRouter", "init_parameters": {"model": "papluca/xlm-roberta-base-language-detection"}, @@ -105,6 +107,7 @@ def test_from_dict_no_default_parameters(self, mock_auto_config_from_pretrained, def 
test_from_dict_with_cpu_device(self, mock_auto_config_from_pretrained, monkeypatch): mock_auto_config_from_pretrained.return_value = MagicMock(label2id={"en": 0, "de": 1}) monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.routers.transformers_text_router.TransformersTextRouter", "init_parameters": { diff --git a/test/components/routers/test_zero_shot_text_router.py b/test/components/routers/test_zero_shot_text_router.py index 8e9759f361..3b931c39bb 100644 --- a/test/components/routers/test_zero_shot_text_router.py +++ b/test/components/routers/test_zero_shot_text_router.py @@ -28,6 +28,7 @@ def test_to_dict(self): def test_from_dict(self, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.routers.zero_shot_text_router.TransformersZeroShotTextRouter", "init_parameters": { @@ -56,6 +57,7 @@ def test_from_dict(self, monkeypatch): def test_from_dict_no_default_parameters(self, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) data = { "type": "haystack.components.routers.zero_shot_text_router.TransformersZeroShotTextRouter", "init_parameters": {"labels": ["query", "passage"]}, From 7b6e9d21aab21bca55eb2fc7d2095bb01be3871f Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 13 Dec 2024 10:30:16 +0100 Subject: [PATCH 03/12] adaptations --- haystack/components/generators/chat/hugging_face_api.py | 7 +------ haystack/components/generators/openai_utils.py | 9 ++------- test/components/generators/chat/test_hugging_face_api.py | 7 ------- test/components/generators/test_openai_utils.py | 7 ------- test/core/pipeline/features/test_run.py | 6 +++--- 5 files changed, 6 insertions(+), 30 deletions(-) diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index 
d4ecd53f10..968eb635d0 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -25,13 +25,8 @@ def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]: :returns: A dictionary with the following keys: - `role` - `content` - - `name` (optional) """ - formatted_msg = {"role": message.role.value, "content": message.content} - if message.name: - formatted_msg["name"] = message.name - - return formatted_msg + return {"role": message.role.value, "content": message.text} @component diff --git a/haystack/components/generators/openai_utils.py b/haystack/components/generators/openai_utils.py index 5b1838c386..ab6d5e7b1d 100644 --- a/haystack/components/generators/openai_utils.py +++ b/haystack/components/generators/openai_utils.py @@ -13,16 +13,11 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, str]: See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details. - :returns: A dictionary with the following key: + :returns: A dictionary with the following keys: - `role` - `content` - - `name` (optional) """ if message.text is None: raise ValueError(f"The provided ChatMessage has no text. 
ChatMessage: {message}") - openai_msg = {"role": message.role.value, "content": message.text} - if message.name: - openai_msg["name"] = message.name - - return openai_msg + return {"role": message.role.value, "content": message.text} diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index 3d7fd617c0..e60ec863ab 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -68,13 +68,6 @@ def test_convert_message_to_hfapi_format(): message = ChatMessage.from_user("I have a question") assert _convert_message_to_hfapi_format(message) == {"role": "user", "content": "I have a question"} - message = ChatMessage.from_function("Function call", "function_name") - assert _convert_message_to_hfapi_format(message) == { - "role": "function", - "content": "Function call", - "name": "function_name", - } - class TestHuggingFaceAPIGenerator: def test_init_invalid_api_type(self): diff --git a/test/components/generators/test_openai_utils.py b/test/components/generators/test_openai_utils.py index 226b32f811..916a3e3d70 100644 --- a/test/components/generators/test_openai_utils.py +++ b/test/components/generators/test_openai_utils.py @@ -14,10 +14,3 @@ def test_convert_message_to_openai_format(): message = ChatMessage.from_user("I have a question") assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"} - - message = ChatMessage.from_function("Function call", "function_name") - assert _convert_message_to_openai_format(message) == { - "role": "function", - "content": "Function call", - "name": "function_name", - } diff --git a/test/core/pipeline/features/test_run.py b/test/core/pipeline/features/test_run.py index d7001a0187..8f07dfec99 100644 --- a/test/core/pipeline/features/test_run.py +++ b/test/core/pipeline/features/test_run.py @@ -1657,7 +1657,7 @@ def run(self, query: str): class 
ToolExtractor: @component.output_types(output=List[str]) def run(self, messages: List[ChatMessage]): - prompt: str = messages[-1].content + prompt: str = messages[-1].text lines = prompt.strip().split("\n") for line in reversed(lines): pattern = r"Action:\s*(\w+)\[(.*?)\]" @@ -1678,14 +1678,14 @@ def __init__(self, suffix: str = ""): @component.output_types(output=List[ChatMessage]) def run(self, replies: List[ChatMessage], current_prompt: List[ChatMessage]): - content = current_prompt[-1].content + replies[-1].content + self._suffix + content = current_prompt[-1].text + replies[-1].text + self._suffix return {"output": [ChatMessage.from_user(content)]} @component class SearchOutputAdapter: @component.output_types(output=List[ChatMessage]) def run(self, replies: List[ChatMessage]): - content = f"Observation: {replies[-1].content}\n" + content = f"Observation: {replies[-1].text}\n" return {"output": [ChatMessage.from_assistant(content)]} pipeline.add_component("prompt_concatenator_after_action", PromptConcatenator()) From c462ddcbfe670ef2b00169a35602366a3e650e6a Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 13 Dec 2024 12:25:07 +0100 Subject: [PATCH 04/12] progress --- haystack/dataclasses/chat_message.py | 75 +++++++++++++++++++++++---- test/dataclasses/test_chat_message.py | 34 ++++++++++++ 2 files changed, 99 insertions(+), 10 deletions(-) diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index e0a437513f..292a9a9078 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -25,6 +25,9 @@ class ChatRole(str, Enum): #: The tool role. A message from a tool contains the result of a Tool invocation. TOOL = "tool" + #: The function role. Deprecated in favor of `TOOL`. 
+ FUNCTION = "function" + @staticmethod def from_str(string: str) -> "ChatRole": """ @@ -94,6 +97,48 @@ class ChatMessage: _content: Sequence[ChatMessageContentT] _meta: Dict[str, Any] = field(default_factory=dict, hash=False) + def __new__(cls, *args, **kwargs): + """ + This method is reimplemented to make the changes to the `ChatMessage` dataclass more visible. + """ + + general_msg = ( + "Use the `from_assistant`, `from_user`, `from_system`, and `from_tool` class methods to create a " + "ChatMessage. Head over to the documentation for more information about the new API and how to migrate:" + " https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" + ) + + if "role" in kwargs or "content" in kwargs or "meta" in kwargs: + raise TypeError( + "The `role`, `content`, and `meta` parameters of `ChatMessage` have been removed. " f"{general_msg}" + ) + + if len(args) > 1 and not isinstance(args[1], (TextContent, ToolCall, ToolCallResult)): + raise TypeError( + "The `content` parameter of `ChatMessage` must be a `ChatMessageContentT` instance. " f"{general_msg}" + ) + + return super(ChatMessage, cls).__new__(cls) + + def __post_init__(self): + if self._role == ChatRole.FUNCTION: + msg = "The `FUNCTION` role has been deprecated in favor of `TOOL` and will be removed in 2.10.0. " + warnings.warn(msg, DeprecationWarning) + + def __getattribute__(self, name): + """ + This method is reimplemented to make the `content` attribute removal more visible. + """ + if name == "content": + msg = ( + "The `content` attribute of `ChatMessage` has been removed. " + "Use the `text` property to access the textual value. 
" + "Head over to the documentation for more information: " + "https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" + ) + raise AttributeError(msg) + return object.__getattribute__(self, name) + def __len__(self): return len(self._content) @@ -170,16 +215,6 @@ def is_from(self, role: Union[ChatRole, str]) -> bool: role = ChatRole.from_str(role) return self._role == role - def __getattribute__(self, name): - # this method is reimplemented to warn about the deprecation of the `content` attribute - if name == "content": - msg = ( - "The `content` attribute of `ChatMessage` will be removed in Haystack 2.9.0. " - "Use the `text` property to access the textual value." - ) - warnings.warn(msg, DeprecationWarning) - return object.__getattribute__(self, name) - @classmethod def from_user(cls, text: str, meta: Optional[Dict[str, Any]] = None) -> "ChatMessage": """ @@ -244,6 +279,26 @@ def from_tool( _meta=meta or {}, ) + @classmethod + def from_function(cls, content: str, name: str) -> "ChatMessage": + """ + Create a message from a function call. Deprecated in favor of `from_tool`. + + :param content: The text content of the message. + :param name: The name of the function being called. + :returns: A new ChatMessage instance. + """ + msg = ( + "The `from_function` method is deprecated and will be removed in version 2.10.0. " + "Its behavior has changed: it now attempts to convert legacy function messages to tool messages. " + "This conversion is not guaranteed to succeed in all scenarios. " + "Please migrate to `ChatMessage.from_tool` and carefully verify the results if you " + "continue to use this method." + ) + warnings.warn(msg) + + return cls.from_tool(content, ToolCall(id=None, tool_name=name, arguments={}), error=False) + def to_dict(self) -> Dict[str, Any]: """ Converts ChatMessage into a dictionary. 
diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index c3357882b4..2f0dc0e56d 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -135,6 +135,19 @@ def test_mixed_content(): assert message.tool_call == content[1] +def test_from_function(): + # check warning is raised + with pytest.warns(): + message = ChatMessage.from_function("Result of function invocation", "my_function") + + assert message.role == ChatRole.TOOL + assert message.tool_call_result == ToolCallResult( + result="Result of function invocation", + origin=ToolCall(id=None, tool_name="my_function", arguments={}), + error=False, + ) + + def test_serde(): # the following message is created just for testing purposes and does not make sense in a real use case @@ -188,6 +201,27 @@ def test_from_dict_with_invalid_content_type(): ChatMessage.from_dict(data) +def test_chat_message_content_attribute_removed(): + message = ChatMessage.from_user(text="This is a message") + with pytest.raises(AttributeError): + message.content + + +def test_chat_message_init_parameters_removed(): + with pytest.raises(TypeError): + ChatMessage(role="irrelevant", content="This is a message") + + +def test_chat_message_init_content_parameter_type(): + with pytest.raises(TypeError): + ChatMessage(ChatRole.USER, "This is a message") + + +def test_chat_message_function_role_deprecated(): + with pytest.warns(DeprecationWarning): + ChatMessage(ChatRole.FUNCTION, TextContent("This is a message")) + + @pytest.mark.integration def test_apply_chat_templating_on_chat_message(): messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")] From 873ae4f53221645b21e0615334b35c8f219a0fab Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 13 Dec 2024 16:35:52 +0100 Subject: [PATCH 05/12] fix type --- haystack/components/generators/chat/hugging_face_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index 968eb635d0..8711a9175a 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -26,7 +26,7 @@ def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]: - `role` - `content` """ - return {"role": message.role.value, "content": message.text} + return {"role": message.role.value, "content": message.text or ""} @component From fe6c4c84e80ac8fcb351bb57b5ee258104287d36 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 13 Dec 2024 16:37:35 +0100 Subject: [PATCH 06/12] import sorting --- haystack/dataclasses/__init__.py | 2 +- haystack/dataclasses/chat_message.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack/dataclasses/__init__.py b/haystack/dataclasses/__init__.py index bcdd6acdd7..91e8f0408f 100644 --- a/haystack/dataclasses/__init__.py +++ b/haystack/dataclasses/__init__.py @@ -4,7 +4,7 @@ from haystack.dataclasses.answer import Answer, ExtractedAnswer, GeneratedAnswer from haystack.dataclasses.byte_stream import ByteStream -from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall, ToolCallResult, TextContent +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent, ToolCall, ToolCallResult from haystack.dataclasses.document import Document from haystack.dataclasses.sparse_embedding import SparseEmbedding from haystack.dataclasses.streaming_chunk import StreamingChunk diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index 292a9a9078..a28e026786 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -5,7 +5,7 @@ import warnings from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Any, Dict, Optional, Union, List, Sequence +from typing import Any, Dict, List, 
Optional, Sequence, Union class ChatRole(str, Enum): From 1a5b46c3b86d1d6d461f6bdee849ada869bb18d5 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 16 Dec 2024 12:10:46 +0100 Subject: [PATCH 07/12] more control on deserialization --- haystack/dataclasses/chat_message.py | 12 ++++++++++-- test/dataclasses/test_chat_message.py | 3 +++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index a28e026786..4d84c7c00c 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -108,9 +108,10 @@ def __new__(cls, *args, **kwargs): " https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" ) - if "role" in kwargs or "content" in kwargs or "meta" in kwargs: + if "role" in kwargs or "content" in kwargs or "meta" in kwargs or "name" in kwargs: raise TypeError( - "The `role`, `content`, and `meta` parameters of `ChatMessage` have been removed. " f"{general_msg}" + "The `role`, `content`, `meta`, and `name` parameters of `ChatMessage` have been removed. " + f"{general_msg}" ) if len(args) > 1 and not isinstance(args[1], (TextContent, ToolCall, ToolCallResult)): @@ -334,6 +335,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": :returns: The created object. """ + if "role" in data or "content" in data or "meta" in data or "name" in data: + raise TypeError( + "The `role`, `content`, `meta`, and `name` parameters of `ChatMessage` have been removed. 
" + "Head over to the documentation for more information about the new API and how to migrate: " + "https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" + ) + data["_role"] = ChatRole(data["_role"]) content: List[ChatMessageContentT] = [] diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index 2f0dc0e56d..c8a30e0610 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -201,6 +201,9 @@ def test_from_dict_with_invalid_content_type(): ChatMessage.from_dict(data) +# def test_chat_message_from_dict_with_invalid_content_type(): + + def test_chat_message_content_attribute_removed(): message = ChatMessage.from_user(text="This is a message") with pytest.raises(AttributeError): From e3f4c89647df9e895275734f9f71d0f4c05074c0 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 16 Dec 2024 18:35:25 +0100 Subject: [PATCH 08/12] release note --- .../new-chatmessage-7f47d5bdeb6ad6f5.yaml | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml diff --git a/releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml b/releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml new file mode 100644 index 0000000000..b9e590e590 --- /dev/null +++ b/releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml @@ -0,0 +1,23 @@ +--- +highlights: > + We are introducing a refactored ChatMessage dataclass. It is more flexible, future-proof, and compatible with + different types of content: text, tool calls, tool call results. + For information about the new API and how to migrate, see the documentation: + https://docs.haystack.deepset.ai/docs/data-classes#chatmessage +upgrade: + - | + The refactoring of the ChatMessage dataclass includes some breaking changes, involving ChatMessage creation and + accessing attributes.
If you have a Pipeline containing a ChatPromptBuilder, serialized using Haystack<2.9.0, + deserialization may break. + For detailed information about the changes and how to migrate, see the documentation: + https://docs.haystack.deepset.ai/docs/data-classes#chatmessage +features: + - | + Changed the ChatMessage dataclass to support different types of content, including tool calls, and tool call + results. +deprecations: + - | + The function role and ChatMessage.from_function class method have been deprecated and will be removed in + Haystack 2.10.0. ChatMessage.from_function also attempts to produce a valid tool message. + For more information, see the documentation: + https://docs.haystack.deepset.ai/docs/data-classes#chatmessage From 180d0f3849d5dc1bf77752d9a251b2da6b081092 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 16 Dec 2024 18:44:01 +0100 Subject: [PATCH 09/12] improvements --- haystack/dataclasses/chat_message.py | 24 +++++++++++++++--------- test/dataclasses/test_chat_message.py | 4 +++- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index 4d84c7c00c..0e937fa600 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -7,6 +7,8 @@ from enum import Enum from typing import Any, Dict, List, Optional, Sequence, Union +LEGACY_INIT_PARAMETERS = {"role", "content", "meta", "name"} + class ChatRole(str, Enum): """ @@ -104,19 +106,22 @@ def __new__(cls, *args, **kwargs): general_msg = ( "Use the `from_assistant`, `from_user`, `from_system`, and `from_tool` class methods to create a " - "ChatMessage. Head over to the documentation for more information about the new API and how to migrate:" + "ChatMessage. 
For more information about the new API and how to migrate, see the documentation:" " https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" ) - if "role" in kwargs or "content" in kwargs or "meta" in kwargs or "name" in kwargs: + if any(param in kwargs for param in LEGACY_INIT_PARAMETERS): raise TypeError( - "The `role`, `content`, `meta`, and `name` parameters of `ChatMessage` have been removed. " + "The `role`, `content`, `meta`, and `name` init parameters of `ChatMessage` have been removed. " f"{general_msg}" ) - if len(args) > 1 and not isinstance(args[1], (TextContent, ToolCall, ToolCallResult)): + allowed_content_types = (TextContent, ToolCall, ToolCallResult) + if len(args) > 1 and not isinstance(args[1], allowed_content_types): raise TypeError( - "The `content` parameter of `ChatMessage` must be a `ChatMessageContentT` instance. " f"{general_msg}" + "The `_content` parameter of `ChatMessage` must be a one of the following types: " + f"{', '.join(t.__name__ for t in allowed_content_types)}. " + f"{general_msg}" ) return super(ChatMessage, cls).__new__(cls) @@ -130,11 +135,12 @@ def __getattribute__(self, name): """ This method is reimplemented to make the `content` attribute removal more visible. """ + if name == "content": msg = ( "The `content` attribute of `ChatMessage` has been removed. " "Use the `text` property to access the textual value. " - "Head over to the documentation for more information: " + "For more information about the new API and how to migrate, see the documentation: " "https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" ) raise AttributeError(msg) @@ -335,10 +341,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": :returns: The created object. """ - if "role" in data or "content" in data or "meta" in data or "name" in data: + if any(param in data for param in LEGACY_INIT_PARAMETERS): raise TypeError( - "The `role`, `content`, `meta`, and `name` parameters of `ChatMessage` have been removed. 
" - "Head over to the documentation for more information about the new API and how to migrate: " + "The `role`, `content`, `meta`, and `name` init parameters of `ChatMessage` have been removed. " + "For more information about the new API and how to migrate, see the documentation: " "https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" ) diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index c8a30e0610..2bfb0635dc 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -201,7 +201,9 @@ def test_from_dict_with_invalid_content_type(): ChatMessage.from_dict(data) -# def test_chat_message_from_dict_with_invalid_content_type(): +def test_from_dict_with_legacy_init_parameters(): + with pytest.raises(TypeError): + ChatMessage.from_dict({"role": "user", "content": "This is a message"}) def test_chat_message_content_attribute_removed(): From 328cebd98696257041933c4973da20157f8d4812 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 17 Dec 2024 10:45:15 +0100 Subject: [PATCH 10/12] support name field --- haystack/dataclasses/chat_message.py | 24 ++++++++++++++++++------ test/dataclasses/test_chat_message.py | 14 +++++++++++++- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index 0e937fa600..db64fb700d 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -97,6 +97,7 @@ class ChatMessage: _role: ChatRole _content: Sequence[ChatMessageContentT] + _name: Optional[str] = None _meta: Dict[str, Any] = field(default_factory=dict, hash=False) def __new__(cls, *args, **kwargs): @@ -163,6 +164,13 @@ def meta(self) -> Dict[str, Any]: """ return self._meta + @property + def name(self) -> Optional[str]: + """ + Returns the name associated with the message. 
+ """ + return self._name + @property def texts(self) -> List[str]: """ @@ -223,32 +231,35 @@ def is_from(self, role: Union[ChatRole, str]) -> bool: return self._role == role @classmethod - def from_user(cls, text: str, meta: Optional[Dict[str, Any]] = None) -> "ChatMessage": + def from_user(cls, text: str, meta: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> "ChatMessage": """ Create a message from the user. :param text: The text content of the message. :param meta: Additional metadata associated with the message. + :param name: An optional name for the participant. This field is only supported by OpenAI. :returns: A new ChatMessage instance. """ - return cls(_role=ChatRole.USER, _content=[TextContent(text=text)], _meta=meta or {}) + return cls(_role=ChatRole.USER, _content=[TextContent(text=text)], _meta=meta or {}, _name=name) @classmethod - def from_system(cls, text: str, meta: Optional[Dict[str, Any]] = None) -> "ChatMessage": + def from_system(cls, text: str, meta: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> "ChatMessage": """ Create a message from the system. :param text: The text content of the message. :param meta: Additional metadata associated with the message. + :param name: An optional name for the participant. This field is only supported by OpenAI. :returns: A new ChatMessage instance. """ - return cls(_role=ChatRole.SYSTEM, _content=[TextContent(text=text)], _meta=meta or {}) + return cls(_role=ChatRole.SYSTEM, _content=[TextContent(text=text)], _meta=meta or {}, _name=name) @classmethod def from_assistant( cls, text: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, tool_calls: Optional[List[ToolCall]] = None, ) -> "ChatMessage": """ @@ -257,6 +268,7 @@ def from_assistant( :param text: The text content of the message. :param meta: Additional metadata associated with the message. :param tool_calls: The Tool calls to include in the message. 
+ :param name: An optional name for the participant. This field is only supported by OpenAI. :returns: A new ChatMessage instance. """ content: List[ChatMessageContentT] = [] @@ -265,7 +277,7 @@ def from_assistant( if tool_calls: content.extend(tool_calls) - return cls(_role=ChatRole.ASSISTANT, _content=content, _meta=meta or {}) + return cls(_role=ChatRole.ASSISTANT, _content=content, _meta=meta or {}, _name=name) @classmethod def from_tool( @@ -316,7 +328,7 @@ def to_dict(self) -> Dict[str, Any]: serialized: Dict[str, Any] = {} serialized["_role"] = self._role.value serialized["_meta"] = self._meta - + serialized["_name"] = self._name content: List[Dict[str, Any]] = [] for part in self._content: if isinstance(part, TextContent): diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index 2bfb0635dc..832617e712 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -31,8 +31,9 @@ def test_from_assistant_with_valid_content(): text = "Hello, how can I assist you?" message = ChatMessage.from_assistant(text) - assert message._role == ChatRole.ASSISTANT + assert message.role == ChatRole.ASSISTANT assert message._content == [TextContent(text)] + assert message.name is None assert message.text == text assert message.texts == [text] @@ -69,6 +70,7 @@ def test_from_user_with_valid_content(): assert message.role == ChatRole.USER assert message._content == [TextContent(text)] + assert message.name is None assert message.text == text assert message.texts == [text] @@ -79,6 +81,15 @@ def test_from_user_with_valid_content(): assert not message.tool_call_result +def test_from_user_with_name(): + text = "I have a question." + message = ChatMessage.from_user(text=text, name="John") + + assert message.name == "John" + assert message.role == ChatRole.USER + assert message._content == [TextContent(text)] + + def test_from_system_with_valid_content(): text = "I have a question." 
message = ChatMessage.from_system(text=text) @@ -174,6 +185,7 @@ def test_serde(): }, ], "_role": "assistant", + "_name": None, "_meta": {"some": "info"}, } From b88daaea0785660b1a05bf32608dd02616c42766 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 17 Dec 2024 10:51:08 +0100 Subject: [PATCH 11/12] fix chatpromptbuilder test --- .../builders/test_chat_prompt_builder.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/test/components/builders/test_chat_prompt_builder.py b/test/components/builders/test_chat_prompt_builder.py index f981afa1b3..a8fb8bc5b8 100644 --- a/test/components/builders/test_chat_prompt_builder.py +++ b/test/components/builders/test_chat_prompt_builder.py @@ -531,8 +531,13 @@ def test_to_dict(self): "type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder", "init_parameters": { "template": [ - {"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}}, - {"_content": [{"text": "content {required_var}"}], "_role": "assistant", "_meta": {}}, + {"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}, "_name": None}, + { + "_content": [{"text": "content {required_var}"}], + "_role": "assistant", + "_meta": {}, + "_name": None, + }, ], "variables": ["var", "required_var"], "required_variables": ["required_var"], @@ -545,8 +550,13 @@ def test_from_dict(self): "type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder", "init_parameters": { "template": [ - {"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}}, - {"_content": [{"text": "content {required_var}"}], "_role": "assistant", "_meta": {}}, + {"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}, "_name": None}, + { + "_content": [{"text": "content {required_var}"}], + "_role": "assistant", + "_meta": {}, + "_name": None, + }, ], "variables": ["var", "required_var"], "required_variables": ["required_var"], From c663d44b00d4896e35c1cfc38cec4536da88150e Mon 
Sep 17 00:00:00 2001 From: Daria Fokina Date: Tue, 17 Dec 2024 16:36:31 +0100 Subject: [PATCH 12/12] Update chat_message.py --- haystack/dataclasses/chat_message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index db64fb700d..5aadb9f752 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -120,7 +120,7 @@ def __new__(cls, *args, **kwargs): allowed_content_types = (TextContent, ToolCall, ToolCallResult) if len(args) > 1 and not isinstance(args[1], allowed_content_types): raise TypeError( - "The `_content` parameter of `ChatMessage` must be a one of the following types: " + "The `_content` parameter of `ChatMessage` must be one of the following types: " f"{', '.join(t.__name__ for t in allowed_content_types)}. " f"{general_msg}" )