diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py
index 014dd7169..d2150f61f 100644
--- a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py
+++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py
@@ -2,7 +2,7 @@
 from typing import Any, Dict, List, Optional
 
 from haystack import component
-from haystack.dataclasses import ChatMessage, ChatRole
+from haystack.dataclasses import ChatMessage
 from llama_cpp import Llama
 from llama_cpp.llama_tokenizer import LlamaHFTokenizer
 
@@ -21,6 +21,10 @@ def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, str]:
     if message.name:
         formatted_msg["name"] = message.name
 
+    if formatted_msg["role"] == "tool":
+        formatted_msg["name"] = message.tool_call_result.origin.tool_name
+        formatted_msg["content"] = message.tool_call_result.result
+
     return formatted_msg
 
 
@@ -114,26 +118,31 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
         formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages]
 
         response = self.model.create_chat_completion(messages=formatted_messages, **updated_generation_kwargs)
-        replies = [
-            ChatMessage(
-                content=choice["message"]["content"],
-                role=ChatRole[choice["message"]["role"].upper()],
-                name=None,
-                meta={
-                    "response_id": response["id"],
-                    "model": response["model"],
-                    "created": response["created"],
-                    "index": choice["index"],
-                    "finish_reason": choice["finish_reason"],
-                    "usage": response["usage"],
-                },
-            )
-            for choice in response["choices"]
-        ]
-
-        for reply, choice in zip(replies, response["choices"]):
+
+        replies = []
+
+        for choice in response["choices"]:
+            meta = {
+                "response_id": response["id"],
+                "model": response["model"],
+                "created": response["created"],
+                "index": choice["index"],
+                "finish_reason": choice["finish_reason"],
+                "usage": response["usage"],
+            }
+
+            name = None
             tool_calls = choice.get("message", {}).get("tool_calls", [])
             if tool_calls:
-                reply.meta["tool_calls"] = tool_calls
-                reply.name = tool_calls[0]["function"]["name"] if tool_calls else None
+                meta["tool_calls"] = tool_calls
+                name = tool_calls[0]["function"]["name"]
+
+            reply = ChatMessage.from_assistant(choice["message"]["content"], meta=meta)
+            if name:
+                if hasattr(reply, "_name"):
+                    reply._name = name  # new ChatMessage
+                elif hasattr(reply, "name"):
+                    reply.name = name  # legacy ChatMessage
+            replies.append(reply)
+
         return {"replies": replies}
diff --git a/integrations/llama_cpp/tests/test_chat_generator.py b/integrations/llama_cpp/tests/test_chat_generator.py
index 0ddd78c4f..87639f684 100644
--- a/integrations/llama_cpp/tests/test_chat_generator.py
+++ b/integrations/llama_cpp/tests/test_chat_generator.py
@@ -41,11 +41,11 @@ def test_convert_message_to_llamacpp_format():
     assert _convert_message_to_llamacpp_format(message) == {"role": "user", "content": "I have a question"}
 
     message = ChatMessage.from_function("Function call", "function_name")
-    assert _convert_message_to_llamacpp_format(message) == {
-        "role": "function",
-        "content": "Function call",
-        "name": "function_name",
-    }
+    converted_message = _convert_message_to_llamacpp_format(message)
+
+    assert converted_message["role"] in ("function", "tool")
+    assert converted_message["name"] == "function_name"
+    assert converted_message["content"] == "Function call"
 
 
 class TestLlamaCppChatGenerator: