fix: make llama.cpp Chat Generator compatible with new ChatMessage (#1254)

* progress

* remove vertex changes from this PR

* fix
anakin87 authored Dec 19, 2024
1 parent 4c478af commit 1aba307
Showing 2 changed files with 35 additions and 26 deletions.
@@ -2,7 +2,7 @@
 from typing import Any, Dict, List, Optional
 
 from haystack import component
-from haystack.dataclasses import ChatMessage, ChatRole
+from haystack.dataclasses import ChatMessage
 from llama_cpp import Llama
 from llama_cpp.llama_tokenizer import LlamaHFTokenizer

@@ -21,6 +21,10 @@ def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, str]:
     if message.name:
         formatted_msg["name"] = message.name
 
+    if formatted_msg["role"] == "tool":
+        formatted_msg["name"] = message.tool_call_result.origin.tool_name
+        formatted_msg["content"] = message.tool_call_result.result
+
     return formatted_msg
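
For context, a sketch of what the added "tool" branch produces for a tool-result message. ChatMessage.from_tool and ToolCall are assumptions about the new haystack-ai API this commit targets, and the import path for the converter is inferred from the repository layout; only the tool_call_result accessors are confirmed by the diff itself:

    # Hypothetical illustration; ChatMessage.from_tool and ToolCall are assumed
    # from the new haystack-ai dataclasses, not shown in this diff.
    from haystack.dataclasses import ChatMessage, ToolCall
    from haystack_integrations.components.generators.llama_cpp.chat.chat_generator import (
        _convert_message_to_llamacpp_format,  # import path assumed from repo layout
    )

    tool_message = ChatMessage.from_tool(
        tool_result="22 degrees, sunny",
        origin=ToolCall(tool_name="get_weather", arguments={"city": "Rome"}),
    )

    # With the branch above, the converter emits:
    # {"role": "tool", "name": "get_weather", "content": "22 degrees, sunny"}
    formatted = _convert_message_to_llamacpp_format(tool_message)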


@@ -114,26 +118,31 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
         formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages]
 
         response = self.model.create_chat_completion(messages=formatted_messages, **updated_generation_kwargs)
-        replies = [
-            ChatMessage(
-                content=choice["message"]["content"],
-                role=ChatRole[choice["message"]["role"].upper()],
-                name=None,
-                meta={
-                    "response_id": response["id"],
-                    "model": response["model"],
-                    "created": response["created"],
-                    "index": choice["index"],
-                    "finish_reason": choice["finish_reason"],
-                    "usage": response["usage"],
-                },
-            )
-            for choice in response["choices"]
-        ]
-
-        for reply, choice in zip(replies, response["choices"]):
-            tool_calls = choice.get("message", {}).get("tool_calls", [])
-            if tool_calls:
-                reply.meta["tool_calls"] = tool_calls
-                reply.name = tool_calls[0]["function"]["name"] if tool_calls else None
+
+        replies = []
+
+        for choice in response["choices"]:
+            meta = {
+                "response_id": response["id"],
+                "model": response["model"],
+                "created": response["created"],
+                "index": choice["index"],
+                "finish_reason": choice["finish_reason"],
+                "usage": response["usage"],
+            }
+
+            name = None
+            tool_calls = choice.get("message", {}).get("tool_calls", [])
+            if tool_calls:
+                meta["tool_calls"] = tool_calls
+                name = tool_calls[0]["function"]["name"]
+
+            reply = ChatMessage.from_assistant(choice["message"]["content"], meta=meta)
+            if name:
+                if hasattr(reply, "_name"):
+                    reply._name = name  # new ChatMessage
+                elif hasattr(reply, "name"):
+                    reply.name = name  # legacy ChatMessage
+            replies.append(reply)
 
         return {"replies": replies}
10 changes: 5 additions & 5 deletions integrations/llama_cpp/tests/test_chat_generator.py
@@ -41,11 +41,11 @@ def test_convert_message_to_llamacpp_format():
     assert _convert_message_to_llamacpp_format(message) == {"role": "user", "content": "I have a question"}
 
     message = ChatMessage.from_function("Function call", "function_name")
-    assert _convert_message_to_llamacpp_format(message) == {
-        "role": "function",
-        "content": "Function call",
-        "name": "function_name",
-    }
+    converted_message = _convert_message_to_llamacpp_format(message)
+
+    assert converted_message["role"] in ("function", "tool")
+    assert converted_message["name"] == "function_name"
+    assert converted_message["content"] == "Function call"
 
 
 class TestLlamaCppChatGenerator:
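
The loosened assertion is the point of the change: with the legacy dataclass, ChatMessage.from_function yields role "function", while the new ChatMessage represents function results as tool messages, so role "tool" is equally valid. A hedged sketch of the same idea, comparing only the keys that are stable across versions (not from this commit):

    # Sketch: compare only the version-stable keys, since "role" depends on
    # which haystack-ai ChatMessage generation is installed.
    converted = _convert_message_to_llamacpp_format(message)
    assert {k: converted[k] for k in ("name", "content")} == {
        "name": "function_name",
        "content": "Function call",
    }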
