From 13ece0f0dc8864faad269bc6a8fc9f32bace636d Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 18 Dec 2024 14:46:24 +0100 Subject: [PATCH 1/5] draft --- .../tests/test_cohere_chat_generator.py | 2 +- .../generators/google_ai/chat/gemini.py | 53 ++++++++++++++----- .../tests/generators/chat/test_chat_gemini.py | 19 +++---- 3 files changed, 52 insertions(+), 22 deletions(-) diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 09f3708eb..4aaa2da2b 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -27,7 +27,7 @@ def streaming_chunk(text: str): @pytest.fixture def chat_messages(): - return [ChatMessage.from_assistant(content="What's the capital of France")] + return [ChatMessage.from_assistant("What's the capital of France")] class TestCohereChatGenerator: diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 089b38b10..255625aae 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -231,9 +231,13 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) def _message_to_part(self, message: ChatMessage) -> Part: - if message.is_from(ChatRole.ASSISTANT) and message.name: + name = getattr(message, "name", None) + if name is None: + name = getattr(message, "_name", None) + + if message.is_from(ChatRole.ASSISTANT) and name: p = Part() - p.function_call.name = message.name + p.function_call.name = name p.function_call.args = {} for k, v in json.loads(message.text).items(): p.function_call.args[k] = v @@ -244,16 +248,27 @@ def _message_to_part(self, message: ChatMessage) -> Part: return p elif message.is_from(ChatRole.FUNCTION): p = Part() - p.function_response.name = message.name + p.function_response.name = name p.function_response.response = message.text return p + elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL): + print("********* HERE *********") + part = Part() + part.function_response.name = message.tool_call_result.origin.tool_name + part.function_response.response = message.tool_call_result.result + print(part) elif message.is_from(ChatRole.USER): return self._convert_part(message.text) def _message_to_content(self, message: ChatMessage) -> Content: - if message.is_from(ChatRole.ASSISTANT) and message.name: + # support both new and legacy ChatMessage + name = getattr(message, "name", None) + if name is None: + name = getattr(message, "_name", None) + + if message.is_from(ChatRole.ASSISTANT) and name: part = Part() - part.function_call.name = message.name + part.function_call.name = name part.function_call.args = {} for k, v in json.loads(message.text).items(): part.function_call.args[k] = v @@ -262,8 +277,14 @@ def _message_to_content(self, message: ChatMessage) -> Content: part.text = message.text elif message.is_from(ChatRole.FUNCTION): part = Part() - part.function_response.name = message.name + part.function_response.name = name part.function_response.response = message.text + elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL): + print("********* HERE *********") + part = Part() + part.function_response.name = message.tool_call_result.origin.tool_name + part.function_response.response = message.tool_call_result.result + print(part) elif message.is_from(ChatRole.USER): part = self._convert_part(message.text) else: @@ -291,9 +312,11 @@ def run( """ streaming_callback = streaming_callback or self._streaming_callback history = [self._message_to_content(m) for m in messages[:-1]] + print(history) session = self._model.start_chat(history=history) new_message = self._message_to_part(messages[-1]) + print(new_message) res = session.send_message( content=new_message, generation_config=self._generation_config, @@ -335,13 +358,16 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess for part in candidate.content.parts: if part.text != "": - replies.append(ChatMessage.from_assistant(content=part.text, meta=candidate_metadata)) + replies.append(ChatMessage.from_assistant(part.text, meta=candidate_metadata)) elif part.function_call: candidate_metadata["function_call"] = part.function_call new_message = ChatMessage.from_assistant( - content=json.dumps(dict(part.function_call.args)), meta=candidate_metadata + json.dumps(dict(part.function_call.args)), meta=candidate_metadata ) - new_message.name = part.function_call.name + try: + new_message.name = part.function_call.name + except AttributeError: + new_message._name = part.function_call.name replies.append(new_message) return replies @@ -364,12 +390,15 @@ def _get_stream_response( for part in candidate["content"]["parts"]: if "text" in part and part["text"] != "": content = part["text"] - replies.append(ChatMessage.from_assistant(content=content, meta=metadata)) + replies.append(ChatMessage.from_assistant(content, meta=metadata)) elif "function_call" in part and len(part["function_call"]) > 0: metadata["function_call"] = part["function_call"] content = json.dumps(dict(part["function_call"]["args"])) - new_message = ChatMessage.from_assistant(content=content, meta=metadata) - new_message.name = part["function_call"]["name"] + new_message = ChatMessage.from_assistant(content, meta=metadata) + try: + new_message.name = part["function_call"]["name"] + except AttributeError: + new_message._name = part["function_call"]["name"] replies.append(new_message) streaming_callback(StreamingChunk(content=content, meta=metadata)) diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index b8658a4dd..1ccf9e1d3 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -215,7 +215,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 tool = Tool(function_declarations=[get_current_weather_func]) gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool]) - messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] + messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 @@ -227,7 +227,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} weather = get_current_weather(**json.loads(chat_message.text)) - messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] + messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 @@ -260,7 +260,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 tool = Tool(function_declarations=[get_current_weather_func]) gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback) - messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] + messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 @@ -272,8 +272,9 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 assert "function_call" in chat_message.meta assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} - weather = get_current_weather(**json.loads(response["replies"][0].text)) - messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] + weather = str(get_current_weather(**json.loads(response["replies"][0].text))) + messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] + print(messages) response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 @@ -289,10 +290,10 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 def test_past_conversation(): gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro") messages = [ - ChatMessage.from_system(content="You are a knowledageable mathematician."), - ChatMessage.from_user(content="What is 2+2?"), - ChatMessage.from_assistant(content="It's an arithmetic operation."), - ChatMessage.from_user(content="Yeah, but what's the result?"), + ChatMessage.from_system("You are a knowledageable mathematician."), + ChatMessage.from_user("What is 2+2?"), + ChatMessage.from_assistant("It's an arithmetic operation."), + ChatMessage.from_user("Yeah, but what's the result?"), ] response = gemini_chat.run(messages=messages) assert "replies" in response From 01fcb808754f3e53d22e17ab3a9b8b64e68e076a Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 18 Dec 2024 15:56:31 +0100 Subject: [PATCH 2/5] improvements --- .../generators/google_ai/chat/gemini.py | 48 ++++++++----------- .../tests/generators/chat/test_chat_gemini.py | 3 +- 2 files changed, 22 insertions(+), 29 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 255625aae..ab99de20a 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -231,13 +231,9 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) def _message_to_part(self, message: ChatMessage) -> Part: - name = getattr(message, "name", None) - if name is None: - name = getattr(message, "_name", None) - - if message.is_from(ChatRole.ASSISTANT) and name: + if message.is_from(ChatRole.ASSISTANT) and message.name: p = Part() - p.function_call.name = name + p.function_call.name = message.name p.function_call.args = {} for k, v in json.loads(message.text).items(): p.function_call.args[k] = v @@ -248,27 +244,21 @@ def _message_to_part(self, message: ChatMessage) -> Part: return p elif message.is_from(ChatRole.FUNCTION): p = Part() - p.function_response.name = name + p.function_response.name = message.name p.function_response.response = message.text return p elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL): - print("********* HERE *********") - part = Part() - part.function_response.name = message.tool_call_result.origin.tool_name - part.function_response.response = message.tool_call_result.result - print(part) + p = Part() + p.function_response.name = message.tool_call_result.origin.tool_name + p.function_response.response = message.tool_call_result.result + return p elif message.is_from(ChatRole.USER): return self._convert_part(message.text) def _message_to_content(self, message: ChatMessage) -> Content: - # support both new and legacy ChatMessage - name = getattr(message, "name", None) - if name is None: - name = getattr(message, "_name", None) - - if message.is_from(ChatRole.ASSISTANT) and name: + if message.is_from(ChatRole.ASSISTANT) and message.name: part = Part() - part.function_call.name = name + part.function_call.name = message.name part.function_call.args = {} for k, v in json.loads(message.text).items(): part.function_call.args[k] = v @@ -277,20 +267,26 @@ def _message_to_content(self, message: ChatMessage) -> Content: part.text = message.text elif message.is_from(ChatRole.FUNCTION): part = Part() - part.function_response.name = name + part.function_response.name = message.name part.function_response.response = message.text + elif message.is_from(ChatRole.USER): + part = self._convert_part(message.text) elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL): - print("********* HERE *********") part = Part() part.function_response.name = message.tool_call_result.origin.tool_name - part.function_response.response = message.tool_call_result.result - print(part) + part.function_response.response = message.tool_call_result.result elif message.is_from(ChatRole.USER): part = self._convert_part(message.text) else: msg = f"Unsupported message role {message.role}" raise ValueError(msg) - role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model" + role = ( + "user" + if message.is_from(ChatRole.USER) + or message.is_from(ChatRole.FUNCTION) + or ("TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL)) + else "model" + ) return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage]) @@ -312,11 +308,9 @@ def run( """ streaming_callback = streaming_callback or self._streaming_callback history = [self._message_to_content(m) for m in messages[:-1]] - print(history) session = self._model.start_chat(history=history) new_message = self._message_to_part(messages[-1]) - print(new_message) res = session.send_message( content=new_message, generation_config=self._generation_config, @@ -395,7 +389,7 @@ def _get_stream_response( metadata["function_call"] = part["function_call"] content = json.dumps(dict(part["function_call"]["args"])) new_message = ChatMessage.from_assistant(content, meta=metadata) - try: + try: new_message.name = part["function_call"]["name"] except AttributeError: new_message._name = part["function_call"]["name"] diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 1ccf9e1d3..0683bf21a 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -272,9 +272,8 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 assert "function_call" in chat_message.meta assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} - weather = str(get_current_weather(**json.loads(response["replies"][0].text))) + weather = get_current_weather(**json.loads(chat_message.text)) messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] - print(messages) response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 From 55fcdda32c6d6e5de99c3c00fb3769c5b2fb3c69 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 18 Dec 2024 16:01:44 +0100 Subject: [PATCH 3/5] small improvemtn --- .../components/generators/google_ai/chat/gemini.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index ab99de20a..ed64102b5 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -280,13 +280,14 @@ def _message_to_content(self, message: ChatMessage) -> Content: else: msg = f"Unsupported message role {message.role}" raise ValueError(msg) - role = ( - "user" - if message.is_from(ChatRole.USER) + + role = "model" + if ( + message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) or ("TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL)) - else "model" - ) + ): + role = "user" return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage]) From 1ac7091db2c2918bc799f5888b7d1597b1d4ce3c Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 18 Dec 2024 16:06:24 +0100 Subject: [PATCH 4/5] rm duplication --- .../components/generators/google_ai/chat/gemini.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index ed64102b5..71e95eaed 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -275,8 +275,6 @@ def _message_to_content(self, message: ChatMessage) -> Content: part = Part() part.function_response.name = message.tool_call_result.origin.tool_name part.function_response.response = message.tool_call_result.result - elif message.is_from(ChatRole.USER): - part = self._convert_part(message.text) else: msg = f"Unsupported message role {message.role}" raise ValueError(msg) From 8f5a6da50c0cb8b6197663b52099179516bd824a Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 18 Dec 2024 17:09:22 +0100 Subject: [PATCH 5/5] simplification --- .../components/generators/google_ai/chat/gemini.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 71e95eaed..69f168a6b 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -279,13 +279,9 @@ def _message_to_content(self, message: ChatMessage) -> Content: msg = f"Unsupported message role {message.role}" raise ValueError(msg) - role = "model" - if ( - message.is_from(ChatRole.USER) - or message.is_from(ChatRole.FUNCTION) - or ("TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL)) - ): - role = "user" + role = "user" + if message.is_from(ChatRole.ASSISTANT) or message.is_from(ChatRole.SYSTEM): + role = "model" return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage])