From 58cb13522cfe9e334b478c346a1f85e238f51a17 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 19 Dec 2024 15:01:16 +0100 Subject: [PATCH] fix: make GoogleAI Chat Generator compatible with new `ChatMessage`; small fixes to Cohere tests (#1253) * draft * improvements * small improvemtn * rm duplication * simplification --- .../tests/test_cohere_chat_generator.py | 2 +- .../generators/google_ai/chat/gemini.py | 32 +++++++++++++++---- .../tests/generators/chat/test_chat_gemini.py | 18 +++++------ 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 09f3708eb..4aaa2da2b 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -27,7 +27,7 @@ def streaming_chunk(text: str): @pytest.fixture def chat_messages(): - return [ChatMessage.from_assistant(content="What's the capital of France")] + return [ChatMessage.from_assistant("What's the capital of France")] class TestCohereChatGenerator: diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 089b38b10..69f168a6b 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -247,6 +247,11 @@ def _message_to_part(self, message: ChatMessage) -> Part: p.function_response.name = message.name p.function_response.response = message.text return p + elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL): + p = Part() + p.function_response.name = message.tool_call_result.origin.tool_name + p.function_response.response = message.tool_call_result.result + return p elif message.is_from(ChatRole.USER): return self._convert_part(message.text) @@ -266,10 +271,17 @@ def _message_to_content(self, message: ChatMessage) -> Content: part.function_response.response = message.text elif message.is_from(ChatRole.USER): part = self._convert_part(message.text) + elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL): + part = Part() + part.function_response.name = message.tool_call_result.origin.tool_name + part.function_response.response = message.tool_call_result.result else: msg = f"Unsupported message role {message.role}" raise ValueError(msg) - role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model" + + role = "user" + if message.is_from(ChatRole.ASSISTANT) or message.is_from(ChatRole.SYSTEM): + role = "model" return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage]) @@ -335,13 +347,16 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess for part in candidate.content.parts: if part.text != "": - replies.append(ChatMessage.from_assistant(content=part.text, meta=candidate_metadata)) + replies.append(ChatMessage.from_assistant(part.text, meta=candidate_metadata)) elif part.function_call: candidate_metadata["function_call"] = part.function_call new_message = ChatMessage.from_assistant( - content=json.dumps(dict(part.function_call.args)), meta=candidate_metadata + json.dumps(dict(part.function_call.args)), meta=candidate_metadata ) - new_message.name = part.function_call.name + try: + new_message.name = part.function_call.name + except AttributeError: + new_message._name = part.function_call.name replies.append(new_message) return replies @@ -364,12 +379,15 @@ def _get_stream_response( for part in candidate["content"]["parts"]: if "text" in part and part["text"] != "": content = part["text"] - replies.append(ChatMessage.from_assistant(content=content, meta=metadata)) + replies.append(ChatMessage.from_assistant(content, meta=metadata)) elif "function_call" in part and len(part["function_call"]) > 0: metadata["function_call"] = part["function_call"] content = json.dumps(dict(part["function_call"]["args"])) - new_message = ChatMessage.from_assistant(content=content, meta=metadata) - new_message.name = part["function_call"]["name"] + new_message = ChatMessage.from_assistant(content, meta=metadata) + try: + new_message.name = part["function_call"]["name"] + except AttributeError: + new_message._name = part["function_call"]["name"] replies.append(new_message) streaming_callback(StreamingChunk(content=content, meta=metadata)) diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index b8658a4dd..0683bf21a 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -215,7 +215,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 tool = Tool(function_declarations=[get_current_weather_func]) gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool]) - messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] + messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 @@ -227,7 +227,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} weather = get_current_weather(**json.loads(chat_message.text)) - messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] + messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 @@ -260,7 +260,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 tool = Tool(function_declarations=[get_current_weather_func]) gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback) - messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] + messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 @@ -272,8 +272,8 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 assert "function_call" in chat_message.meta assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} - weather = get_current_weather(**json.loads(response["replies"][0].text)) - messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] + weather = get_current_weather(**json.loads(chat_message.text)) + messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 @@ -289,10 +289,10 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 def test_past_conversation(): gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro") messages = [ - ChatMessage.from_system(content="You are a knowledageable mathematician."), - ChatMessage.from_user(content="What is 2+2?"), - ChatMessage.from_assistant(content="It's an arithmetic operation."), - ChatMessage.from_user(content="Yeah, but what's the result?"), + ChatMessage.from_system("You are a knowledageable mathematician."), + ChatMessage.from_user("What is 2+2?"), + ChatMessage.from_assistant("It's an arithmetic operation."), + ChatMessage.from_user("Yeah, but what's the result?"), ] response = gemini_chat.run(messages=messages) assert "replies" in response