From 943f8e50cc9dc30ce6a8fd867bcb802dad2c2e74 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 10 Dec 2024 18:15:37 +0100 Subject: [PATCH] fix: GoogleAI - fix the content type of `ChatMessage` `content` from function (#1241) * fix Gemini * avoid directly accessing role --- .../generators/google_ai/chat/gemini.py | 49 ++++++++++--------- .../tests/generators/chat/test_chat_gemini.py | 13 ++--- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index ef7d583be..089b38b10 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -1,3 +1,4 @@ +import json import logging from typing import Any, Callable, Dict, List, Optional, Union @@ -36,12 +37,12 @@ class GoogleAIGeminiChatGenerator: messages = [ChatMessage.from_user("What is the most interesting thing you know?")] res = gemini_chat.run(messages=messages) for reply in res["replies"]: - print(reply.content) + print(reply.text) messages += res["replies"] + [ChatMessage.from_user("Tell me more about it")] res = gemini_chat.run(messages=messages) for reply in res["replies"]: - print(reply.content) + print(reply.text) ``` @@ -85,14 +86,14 @@ def get_current_weather(location: str, unit: str = "celsius") -> str: gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", api_key=Secret.from_token(""), tools=[tool]) - messages = [ChatMessage.from_user(content = "What is the temperature in celsius in Berlin?")] + messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] res = gemini_chat.run(messages=messages) - weather = get_current_weather(**res["replies"][0].content) + weather = get_current_weather(**json.loads(res["replies"][0].text)) messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] res = gemini_chat.run(messages=messages) for reply in res["replies"]: - print(reply.content) + print(reply.text) ``` """ @@ -230,45 +231,45 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) def _message_to_part(self, message: ChatMessage) -> Part: - if message.role == ChatRole.ASSISTANT and message.name: + if message.is_from(ChatRole.ASSISTANT) and message.name: p = Part() p.function_call.name = message.name p.function_call.args = {} - for k, v in message.content.items(): + for k, v in json.loads(message.text).items(): p.function_call.args[k] = v return p - elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: + elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): p = Part() - p.text = message.content + p.text = message.text return p - elif message.role == ChatRole.FUNCTION: + elif message.is_from(ChatRole.FUNCTION): p = Part() p.function_response.name = message.name - p.function_response.response = message.content + p.function_response.response = message.text return p - elif message.role == ChatRole.USER: - return self._convert_part(message.content) + elif message.is_from(ChatRole.USER): + return self._convert_part(message.text) def _message_to_content(self, message: ChatMessage) -> Content: - if message.role == ChatRole.ASSISTANT and message.name: + if message.is_from(ChatRole.ASSISTANT) and message.name: part = Part() part.function_call.name = message.name part.function_call.args = {} - for k, v in message.content.items(): + for k, v in json.loads(message.text).items(): part.function_call.args[k] = v - elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: + elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): part = Part() - part.text = message.content - elif message.role == ChatRole.FUNCTION: + part.text = message.text + elif message.is_from(ChatRole.FUNCTION): part = Part() part.function_response.name = message.name - part.function_response.response = message.content - elif message.role == ChatRole.USER: - part = self._convert_part(message.content) + part.function_response.response = message.text + elif message.is_from(ChatRole.USER): + part = self._convert_part(message.text) else: msg = f"Unsupported message role {message.role}" raise ValueError(msg) - role = "user" if message.role in [ChatRole.USER, ChatRole.FUNCTION] else "model" + role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model" return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage]) @@ -338,7 +339,7 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess elif part.function_call: candidate_metadata["function_call"] = part.function_call new_message = ChatMessage.from_assistant( - content=dict(part.function_call.args.items()), meta=candidate_metadata + content=json.dumps(dict(part.function_call.args)), meta=candidate_metadata ) new_message.name = part.function_call.name replies.append(new_message) @@ -366,7 +367,7 @@ def _get_stream_response( replies.append(ChatMessage.from_assistant(content=content, meta=metadata)) elif "function_call" in part and len(part["function_call"]) > 0: metadata["function_call"] = part["function_call"] - content = part["function_call"]["args"] + content = json.dumps(dict(part["function_call"]["args"])) new_message = ChatMessage.from_assistant(content=content, meta=metadata) new_message.name = part["function_call"]["name"] replies.append(new_message) diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index cb42f0ff8..b8658a4dd 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -1,3 +1,4 @@ +import json import os from unittest.mock import patch @@ -223,9 +224,9 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 # check the first response is a function call chat_message = response["replies"][0] assert "function_call" in chat_message.meta - assert chat_message.content == {"location": "Berlin", "unit": "celsius"} + assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} - weather = get_current_weather(**chat_message.content) + weather = get_current_weather(**json.loads(chat_message.text)) messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] response = gemini_chat.run(messages=messages) assert "replies" in response @@ -235,7 +236,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 # check the second response is not a function call chat_message = response["replies"][0] assert "function_call" not in chat_message.meta - assert isinstance(chat_message.content, str) + assert isinstance(chat_message.text, str) @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") @@ -269,9 +270,9 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 # check the first response is a function call chat_message = response["replies"][0] assert "function_call" in chat_message.meta - assert chat_message.content == {"location": "Berlin", "unit": "celsius"} + assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} - weather = get_current_weather(**response["replies"][0].content) + weather = get_current_weather(**json.loads(response["replies"][0].text)) messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] response = gemini_chat.run(messages=messages) assert "replies" in response @@ -281,7 +282,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 # check the second response is not a function call chat_message = response["replies"][0] assert "function_call" not in chat_message.meta - assert isinstance(chat_message.content, str) + assert isinstance(chat_message.text, str) @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")