From 3fd6af5be4f0de56946e81762d0ddd488eced1f4 Mon Sep 17 00:00:00 2001
From: Amna Mubashar
Date: Tue, 17 Sep 2024 13:39:31 +0200
Subject: [PATCH 01/10] Add new tests

---
 .../google_vertex/tests/chat/test_gemini.py  | 24 ++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py
index a1564b9f2..f6e0d6323 100644
--- a/integrations/google_vertex/tests/chat/test_gemini.py
+++ b/integrations/google_vertex/tests/chat/test_gemini.py
@@ -3,7 +3,7 @@
 import pytest
 from haystack import Pipeline
 from haystack.components.builders import ChatPromptBuilder
-from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.dataclasses import ChatMessage, StreamingChunk, ChatRole
 from vertexai.generative_models import (
     Content,
     FunctionDeclaration,
@@ -293,3 +293,25 @@ def test_serialization_deserialization_pipeline():
     new_pipeline = Pipeline.from_dict(pipeline_dict)

     assert new_pipeline == pipeline
+
+
+@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
+def test_role_in_messages(mock_generative_model):
+    mock_model = Mock()
+    mock_candidate = Mock(content=Content(parts=[Part.from_text("This is a generated response.")], role="model"))
+    mock_response = MagicMock(spec=GenerationResponse, candidates=[mock_candidate])
+
+    mock_model.send_message.return_value = mock_response
+    mock_model.start_chat.return_value = mock_model
+    mock_generative_model.return_value = mock_model
+
+    messages = [
+        ChatMessage.from_system("You are a helpful assistant"),
+        ChatMessage.from_user("What's the capital of France?"),
+    ]
+    gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None)
+    response = gemini.run(messages=messages)
+    assert response["replies"][0].is_from(ChatRole.ASSISTANT)
+    messages += [response["replies"][0], ChatMessage.from_user("How big is this city?")]
+
+    mock_model.send_message.assert_called_once()

From 56ab69e04eff607a627841dcc8cde1687b436856 Mon Sep 17 00:00:00 2001
From: Amna Mubashar
Date: Tue, 17 Sep 2024 19:00:19 +0200
Subject: [PATCH 02/10] Add check for roles

---
 .../generators/google_ai/chat/gemini.py       | 24 ++++++++--
 .../tests/generators/chat/test_chat_gemini.py | 48 ++++++++++++++-----
 .../google_vertex/tests/chat/test_gemini.py   | 40 ++++++----------
 3 files changed, 70 insertions(+), 42 deletions(-)

diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
index e859a29fd..198a8e216 100644
--- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
+++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
@@ -336,11 +336,27 @@ def _get_stream_response(
         :param streaming_callback: The handler for the streaming response.
         :returns: The extracted response with the content of all streaming chunks.
""" - responses = [] + responses: Union[List[str], List[ChatMessage]] = [] for chunk in stream: - content = chunk.text if len(chunk.parts) > 0 and "text" in chunk.parts[0] else "" - streaming_callback(StreamingChunk(content=content, meta=chunk.to_dict())) - responses.append(content) + for candidate in chunk.candidates: + for part in candidate.content.parts: + if part.text != "": + content = part.text + responses.append(content) + elif part.function_call is not None: + content = dict(part.function_call.args.items()) + responses.append( + ChatMessage( + content=dict(part.function_call.args.items()), + role=ChatRole.ASSISTANT, + name=part.function_call.name, + ) + ) + + streaming_callback(StreamingChunk(content=content, meta=chunk.to_dict())) + + if isinstance(responses[0], ChatMessage): + return responses combined_response = "".join(responses).lstrip() return [ChatMessage.from_assistant(content=combined_response)] diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 35ad8db14..69a192e1b 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -5,7 +5,7 @@ from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool from haystack.dataclasses import StreamingChunk -from haystack.dataclasses.chat_message import ChatMessage +from haystack.dataclasses.chat_message import ChatMessage, ChatRole from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator @@ -215,14 +215,21 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 tool = Tool(function_declarations=[get_current_weather_func]) gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool]) messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] - res = gemini_chat.run(messages=messages) - assert len(res["replies"]) > 0 - - weather = get_current_weather(**res["replies"][0].content) - messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] - - res = gemini_chat.run(messages=messages) - assert len(res["replies"]) > 0 + response = gemini_chat.run(messages=messages) + assert "replies" in response + assert len(response["replies"]) > 0 + chat_message = response["replies"][0] + assert chat_message.content + assert chat_message.is_from(ChatRole.ASSISTANT) + + weather = get_current_weather(**chat_message.content) + messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] + response = gemini_chat.run(messages=messages) + assert "replies" in response + assert len(response["replies"]) > 0 + chat_message = response["replies"][0] + assert chat_message.content + assert chat_message.is_from(ChatRole.ASSISTANT) @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") @@ -247,10 +254,21 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 tool = Tool(function_declarations=[get_current_weather_func]) gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback) messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] - res = gemini_chat.run(messages=messages) - assert len(res["replies"]) > 0 + response = 
+    assert "replies" in response
+    assert len(response["replies"]) > 0
+    chat_message = response["replies"][0]
+    assert chat_message.is_from(ChatRole.ASSISTANT)
     assert streaming_callback_called

+    weather = get_current_weather(**chat_message.content)
+    messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
+    response = gemini_chat.run(messages=messages)
+    assert "replies" in response
+    assert len(response["replies"]) > 0
+    chat_message = response["replies"][-1]
+    assert chat_message.is_from(ChatRole.ASSISTANT)
+

 @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
 def test_past_conversation():
@@ -261,5 +279,9 @@
         ChatMessage.from_assistant(content="It's an arithmetic operation."),
         ChatMessage.from_user(content="Yeah, but what's the result?"),
     ]
-    res = gemini_chat.run(messages=messages)
-    assert len(res["replies"]) > 0
+    response = gemini_chat.run(messages=messages)
+    assert "replies" in response
+    assert len(response["replies"]) > 0
+    chat_message = response["replies"][0]
+    assert chat_message.content
+    assert chat_message.is_from(ChatRole.ASSISTANT)
diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py
index f6e0d6323..62e41d850 100644
--- a/integrations/google_vertex/tests/chat/test_gemini.py
+++ b/integrations/google_vertex/tests/chat/test_gemini.py
@@ -3,7 +3,7 @@
 import pytest
 from haystack import Pipeline
 from haystack.components.builders import ChatPromptBuilder
-from haystack.dataclasses import ChatMessage, StreamingChunk, ChatRole
+from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
 from vertexai.generative_models import (
     Content,
     FunctionDeclaration,
@@ -249,9 +249,15 @@ def test_run(mock_generative_model):
         ChatMessage.from_user("What's the capital of France?"),
     ]
     gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None)
-    gemini.run(messages=messages)
+    response = gemini.run(messages=messages)

     mock_model.send_message.assert_called_once()
+    assert "replies" in response
+    assert len(response["replies"]) == 1
+
+    chat_message = response["replies"][0]
+    assert chat_message.content
+    assert chat_message.is_from(ChatRole.ASSISTANT)


 @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
@@ -275,10 +281,16 @@ def streaming_callback(chunk: StreamingChunk) -> None:
         ChatMessage.from_system("You are a helpful assistant"),
         ChatMessage.from_user("What's the capital of France?"),
     ]
-    gemini.run(messages=messages)
+    response = gemini.run(messages=messages)

     mock_model.send_message.assert_called_once()
     assert streaming_callback_called == ["First part", "Second part"]
+    assert "replies" in response
+    assert len(response["replies"]) == 1
+
+    chat_message = response["replies"][0]
+    assert chat_message.content
+    assert chat_message.is_from(ChatRole.ASSISTANT)


 def test_serialization_deserialization_pipeline():
@@ -293,25 +305,3 @@ def test_serialization_deserialization_pipeline():
     new_pipeline = Pipeline.from_dict(pipeline_dict)

     assert new_pipeline == pipeline
-
-
-@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
-def test_role_in_messages(mock_generative_model):
-    mock_model = Mock()
-    mock_candidate = Mock(content=Content(parts=[Part.from_text("This is a generated response.")], role="model"))
-    mock_response = MagicMock(spec=GenerationResponse, candidates=[mock_candidate])
-
-    mock_model.send_message.return_value = mock_response
-    mock_model.start_chat.return_value = mock_model
-    mock_generative_model.return_value = mock_model
-
-    messages = [
-        ChatMessage.from_system("You are a helpful assistant"),
-        ChatMessage.from_user("What's the capital of France?"),
-    ]
-    gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None)
-    response = gemini.run(messages=messages)
-    assert response["replies"][0].is_from(ChatRole.ASSISTANT)
-    messages += [response["replies"][0], ChatMessage.from_user("How big is this city?")]
-
-    mock_model.send_message.assert_called_once()

From 82f503565facb43cd08ef0d887253dfb0c7336f4 Mon Sep 17 00:00:00 2001
From: Amna Mubashar
Date: Wed, 18 Sep 2024 14:49:54 +0200
Subject: [PATCH 03/10] Add metadata to chat responses

---
 .../generators/google_vertex/chat/gemini.py   |  6 +-
 .../generators/google_vertex/chat/main.py     | 45 +++++++++++++++
 .../google_vertex/tests/chat/test_gemini.py   | 57 +++++++++++++++++--
 3 files changed, 103 insertions(+), 5 deletions(-)
 create mode 100644 integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/main.py

diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
index e693c10f4..4b3dd0779 100644
--- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
+++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
@@ -232,14 +232,18 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
         replies = []
         for candidate in response_body.candidates:
             for part in candidate.content.parts:
+                metadata=candidate.to_dict()
+                metadata.pop("content")
                 if part._raw_part.text != "":
-                    replies.append(ChatMessage.from_assistant(part.text))
+                    replies.append(ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT,name = None, meta=metadata))
                 elif part.function_call is not None:
+
                     replies.append(
                         ChatMessage(
                             content=dict(part.function_call.args.items()),
                             role=ChatRole.ASSISTANT,
                             name=part.function_call.name,
+                            meta=metadata
                         )
                     )
         return replies
diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/main.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/main.py
new file mode 100644
index 000000000..acf8d782e
--- /dev/null
+++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/main.py
@@ -0,0 +1,45 @@
+from vertexai.generative_models import Tool, FunctionDeclaration
+
+get_current_weather_func = FunctionDeclaration(
+    name="get_current_weather",
+    description="Get the current weather in a given location",
+    parameters={
+        "type": "object",
+        "properties": {
+            "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
San Francisco, CA"}, + "unit": { + "type": "string", + "enum": [ + "celsius", + "fahrenheit", + ], + }, + }, + "required": ["location"], + }, +) +tool = Tool([get_current_weather_func]) + +def get_current_weather(location: str, unit: str = "celsius"): + return {"weather": "sunny", "temperature": 21.8, "unit": unit} + +from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator + + +gemini_chat = VertexAIGeminiChatGenerator(project_id="my-project-1487737228087", tools=[tool]) +from haystack.dataclasses import ChatMessage + + +messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] +res = gemini_chat.run(messages=messages) +print ("RESPONSE") +print (res) +print(res["replies"][0].content) + +weather = get_current_weather(**res["replies"][0].content) + +messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] + +res = gemini_chat.run(messages=messages) +print (res) +print(res["replies"][0].content) diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index 62e41d850..eb8992f48 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -253,11 +253,10 @@ def test_run(mock_generative_model): mock_model.send_message.assert_called_once() assert "replies" in response - assert len(response["replies"]) == 1 + assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) - chat_message = response["replies"][0] - assert chat_message.content - assert chat_message.is_from(ChatRole.ASSISTANT) + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") @@ -305,3 +304,53 @@ def test_serialization_deserialization_pipeline(): new_pipeline = Pipeline.from_dict(pipeline_dict) assert new_pipeline == pipeline + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_function_call_and_execute(mock_generative_model): + mock_model = Mock() + mock_candidate = Mock(content=Content(parts=[Part.from_text("This is a generated response.")], role="model")) + mock_response = MagicMock(spec=GenerationResponse, candidates=[mock_candidate]) + + mock_model.send_message.return_value = mock_response + mock_model.start_chat.return_value = mock_model + mock_generative_model.return_value = mock_model + + get_current_weather_func = FunctionDeclaration( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state, e.g. 
San Francisco, CA"}, + "unit": { + "type": "string", + "enum": [ + "celsius", + "fahrenheit", + ], + }, + }, + "required": ["location"], + }, + ) + + def get_current_weather(location: str, unit: str = "celsius"): + return {"weather": "sunny", "temperature": 21.8, "unit": unit} + + + tool = Tool(function_declarations=[get_current_weather_func]) + messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] + gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, tools=[tool]) + + response = gemini.run(messages=messages) + assert "replies" in response + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + + assert len(response["replies"]) > 0 + print (response) + + first_reply = response["replies"][0] + assert "tool_calls" in first_reply.meta + tool_calls = first_reply.meta["tool_calls"] + + \ No newline at end of file From 1bdfdb32a8001991ce48a5499168c166f49f67b1 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 19 Sep 2024 14:35:13 +0200 Subject: [PATCH 04/10] Fix meta in chat --- .../generators/google_ai/chat/gemini.py | 29 ++++++--- .../tests/generators/chat/test_chat_gemini.py | 24 +++---- .../generators/google_vertex/chat/gemini.py | 12 ++-- .../generators/google_vertex/chat/main.py | 45 ------------- .../google_vertex/tests/chat/test_gemini.py | 63 ++----------------- 5 files changed, 43 insertions(+), 130 deletions(-) delete mode 100644 integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/main.py diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 198a8e216..6905e2c7e 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -312,16 +312,20 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess :returns: The extracted responses. """ replies = [] + metadata = response_body.to_dict() + [candidate.pop("content", None) for candidate in metadata["candidates"]] for candidate in response_body.candidates: for part in candidate.content.parts: if part.text != "": - replies.append(ChatMessage.from_assistant(part.text)) + replies.append(ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)) elif part.function_call is not None: + metadata["function_call"] = part.function_call replies.append( ChatMessage( content=dict(part.function_call.args.items()), role=ChatRole.ASSISTANT, name=part.function_call.name, + meta=metadata, ) ) return replies @@ -336,27 +340,32 @@ def _get_stream_response( :param streaming_callback: The handler for the streaming response. :returns: The extracted response with the content of all streaming chunks. 
""" - responses: Union[List[str], List[ChatMessage]] = [] + replies: Union[List[str], List[ChatMessage]] = [] + metadata = stream.to_dict() + + for candidate in metadata.get("candidates", []): + candidate.pop("content", None) + for chunk in stream: for candidate in chunk.candidates: for part in candidate.content.parts: if part.text != "": - content = part.text - responses.append(content) + replies.append(part.text) elif part.function_call is not None: - content = dict(part.function_call.args.items()) - responses.append( + metadata["function_call"] = part.function_call + replies.append( ChatMessage( content=dict(part.function_call.args.items()), role=ChatRole.ASSISTANT, name=part.function_call.name, + meta=metadata, ) ) - streaming_callback(StreamingChunk(content=content, meta=chunk.to_dict())) + streaming_callback(StreamingChunk(content=part.text, meta=chunk.to_dict())) - if isinstance(responses[0], ChatMessage): - return responses + if isinstance(replies[0], ChatMessage): + return replies - combined_response = "".join(responses).lstrip() + combined_response = "".join(replies).lstrip() return [ChatMessage.from_assistant(content=combined_response)] diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 69a192e1b..6177ed359 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -207,7 +207,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 get_current_weather_func = FunctionDeclaration.from_function( get_current_weather, descriptions={ - "location": "The city and state, e.g. San Francisco, CA", + "location": "The city and state, e.g. San Francisco", "unit": "The temperature unit of measurement, e.g. 
celsius or fahrenheit", }, ) @@ -218,18 +218,22 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + chat_message = response["replies"][0] - assert chat_message.content - assert chat_message.is_from(ChatRole.ASSISTANT) + assert "function_call" in chat_message.meta + assert chat_message.content == {"location": "Berlin", "unit": "celsius"} weather = get_current_weather(**chat_message.content) messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + chat_message = response["replies"][0] + assert "function_call" not in chat_message.meta assert chat_message.content - assert chat_message.is_from(ChatRole.ASSISTANT) @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") @@ -257,17 +261,15 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 - chat_message = response["replies"][0] - assert chat_message.is_from(ChatRole.ASSISTANT) + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) assert streaming_callback_called - weather = get_current_weather(**chat_message.content) + weather = get_current_weather(**response["replies"][0].content) messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 - chat_message = response["replies"][-1] - assert chat_message.is_from(ChatRole.ASSISTANT) + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") @@ -282,6 +284,4 @@ def test_past_conversation(): response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 - chat_message = response["replies"][0] - assert chat_message.content - assert chat_message.is_from(ChatRole.ASSISTANT) + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index 4b3dd0779..e910346e1 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -231,19 +231,21 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: """ replies = [] for candidate in response_body.candidates: + metadata = candidate.to_dict() + metadata.pop("content") for part in candidate.content.parts: - metadata=candidate.to_dict() - metadata.pop("content") if part._raw_part.text != "": - replies.append(ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT,name = None, meta=metadata)) + replies.append( + 
+                        ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)
+                    )
                 elif part.function_call is not None:
-
+                    metadata["function_call"] = part.function_call
                     replies.append(
                         ChatMessage(
                             content=dict(part.function_call.args.items()),
                             role=ChatRole.ASSISTANT,
                             name=part.function_call.name,
-                            meta=metadata
+                            meta=metadata,
                         )
                     )
         return replies
diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/main.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/main.py
deleted file mode 100644
index acf8d782e..000000000
--- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/main.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from vertexai.generative_models import Tool, FunctionDeclaration
-
-get_current_weather_func = FunctionDeclaration(
-    name="get_current_weather",
-    description="Get the current weather in a given location",
-    parameters={
-        "type": "object",
-        "properties": {
-            "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
-            "unit": {
-                "type": "string",
-                "enum": [
-                    "celsius",
-                    "fahrenheit",
-                ],
-            },
-        },
-        "required": ["location"],
-    },
-)
-tool = Tool([get_current_weather_func])
-
-def get_current_weather(location: str, unit: str = "celsius"):
-    return {"weather": "sunny", "temperature": 21.8, "unit": unit}
-
-from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator
-
-
-gemini_chat = VertexAIGeminiChatGenerator(project_id="my-project-1487737228087", tools=[tool])
-from haystack.dataclasses import ChatMessage
-
-
-messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")]
-res = gemini_chat.run(messages=messages)
-print ("RESPONSE")
-print (res)
-print(res["replies"][0].content)
-
-weather = get_current_weather(**res["replies"][0].content)
-
-messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
-
-res = gemini_chat.run(messages=messages)
-print (res)
-print(res["replies"][0].content)
diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py
index eb8992f48..736702ce5 100644
--- a/integrations/google_vertex/tests/chat/test_gemini.py
+++ b/integrations/google_vertex/tests/chat/test_gemini.py
@@ -256,14 +256,12 @@ def test_run(mock_generative_model):
     assert len(response["replies"]) > 0
     assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

-
-

 @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
 def test_run_with_streaming_callback(mock_generative_model):
     mock_model = Mock()
     mock_responses = iter(
-        [MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text="Second part")]
+        [MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text=" Second part")]
     )

     mock_model.send_message.return_value = mock_responses
@@ -283,13 +281,12 @@ def streaming_callback(chunk: StreamingChunk) -> None:
     response = gemini.run(messages=messages)

     mock_model.send_message.assert_called_once()
-    assert streaming_callback_called == ["First part", "Second part"]
+    assert streaming_callback_called == ["First part", " Second part"]
     assert "replies" in response
-    assert len(response["replies"]) == 1
+    assert len(response["replies"]) > 0

-    chat_message = response["replies"][0]
-    assert chat_message.content
-    assert chat_message.is_from(ChatRole.ASSISTANT)
+    assert response["replies"][0].content == "First part Second part"
+    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])


 def test_serialization_deserialization_pipeline():
@@ -304,53 +301,3 @@ def test_serialization_deserialization_pipeline():
     new_pipeline = Pipeline.from_dict(pipeline_dict)

     assert new_pipeline == pipeline
-
-@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
-def test_function_call_and_execute(mock_generative_model):
-    mock_model = Mock()
-    mock_candidate = Mock(content=Content(parts=[Part.from_text("This is a generated response.")], role="model"))
-    mock_response = MagicMock(spec=GenerationResponse, candidates=[mock_candidate])
-
-    mock_model.send_message.return_value = mock_response
-    mock_model.start_chat.return_value = mock_model
-    mock_generative_model.return_value = mock_model
-
-    get_current_weather_func = FunctionDeclaration(
-        name="get_current_weather",
-        description="Get the current weather in a given location",
-        parameters={
-            "type": "object",
-            "properties": {
-                "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
-                "unit": {
-                    "type": "string",
-                    "enum": [
-                        "celsius",
-                        "fahrenheit",
-                    ],
-                },
-            },
-            "required": ["location"],
-        },
-    )
-
-    def get_current_weather(location: str, unit: str = "celsius"):
-        return {"weather": "sunny", "temperature": 21.8, "unit": unit}
-
-
-    tool = Tool(function_declarations=[get_current_weather_func])
-    messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
-    gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, tools=[tool])
-
-    response = gemini.run(messages=messages)
-    assert "replies" in response
-    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
-
-    assert len(response["replies"]) > 0
-    print (response)
-
-    first_reply = response["replies"][0]
-    assert "tool_calls" in first_reply.meta
-    tool_calls = first_reply.meta["tool_calls"]
-
-    
\ No newline at end of file

From fee4a7050bd38a42ccd8df0613f5cdfdc2a9817a Mon Sep 17 00:00:00 2001
From: Amna Mubashar
Date: Fri, 20 Sep 2024 01:20:17 +0200
Subject: [PATCH 05/10] Small fixes

---
 .../generators/google_ai/chat/gemini.py       | 16 +++++----------
 .../tests/generators/chat/test_chat_gemini.py |  2 +-
 .../generators/google_vertex/chat/gemini.py   | 20 ++++++++++++++-----
 .../google_vertex/tests/chat/test_gemini.py   |  2 --
 4 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
index 6905e2c7e..dbc88f514 100644
--- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
+++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
@@ -341,16 +341,15 @@ def _get_stream_response(
         :param streaming_callback: The handler for the streaming response.
         :returns: The extracted response with the content of all streaming chunks.
""" replies: Union[List[str], List[ChatMessage]] = [] - metadata = stream.to_dict() - - for candidate in metadata.get("candidates", []): - candidate.pop("content", None) for chunk in stream: + metadata = chunk.to_dict() for candidate in chunk.candidates: for part in candidate.content.parts: if part.text != "": - replies.append(part.text) + replies.append( + ChatMessage(content=part.text, role=ChatRole.ASSISTANT, meta=metadata, name=None) + ) elif part.function_call is not None: metadata["function_call"] = part.function_call replies.append( @@ -363,9 +362,4 @@ def _get_stream_response( ) streaming_callback(StreamingChunk(content=part.text, meta=chunk.to_dict())) - - if isinstance(replies[0], ChatMessage): - return replies - - combined_response = "".join(replies).lstrip() - return [ChatMessage.from_assistant(content=combined_response)] + return replies diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 6177ed359..77ffdadb2 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -269,7 +269,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 - assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + assert all(reply.role == ChatRole.SYSTEM for reply in response["replies"]) @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index e910346e1..4265760de 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -232,7 +232,6 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: replies = [] for candidate in response_body.candidates: metadata = candidate.to_dict() - metadata.pop("content") for part in candidate.content.parts: if part._raw_part.text != "": replies.append( @@ -260,11 +259,22 @@ def _get_stream_response( :param streaming_callback: The handler for the streaming response. :returns: The extracted response with the content of all streaming chunks. 
""" - responses = [] + replies = [] for chunk in stream: + metadata = chunk.to_dict() streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict()) streaming_callback(streaming_chunk) - responses.append(streaming_chunk.content) - combined_response = "".join(responses).lstrip() - return [ChatMessage.from_assistant(content=combined_response)] + if chunk.text != "": + replies.append(ChatMessage(chunk.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)) + elif chunk.function_call is not None: + metadata["function_call"] = chunk.function_call + replies.append( + ChatMessage( + content=dict(chunk.function_call.args.items()), + role=ChatRole.ASSISTANT, + name=chunk.function_call.name, + meta=metadata, + ) + ) + return replies diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index 736702ce5..e4e6fa487 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -284,8 +284,6 @@ def streaming_callback(chunk: StreamingChunk) -> None: assert streaming_callback_called == ["First part", " Second part"] assert "replies" in response assert len(response["replies"]) > 0 - - assert response["replies"][0].content == "First part Second part" assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) From 1f53cd944f9a7ef5385039110dc848dc817cf412 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 20 Sep 2024 13:24:37 +0200 Subject: [PATCH 06/10] Fixed error in vertex streaming --- .../tests/generators/chat/test_chat_gemini.py | 2 +- .../generators/google_vertex/chat/gemini.py | 35 +++++++++++-------- .../google_vertex/tests/chat/test_gemini.py | 12 +++---- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 77ffdadb2..6177ed359 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -269,7 +269,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 - assert all(reply.role == ChatRole.SYSTEM for reply in response["replies"]) + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index 4265760de..3673ad4c7 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -237,7 +237,7 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: replies.append( ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata) ) - elif part.function_call is not None: + elif part.function_call: metadata["function_call"] = part.function_call replies.append( ChatMessage( @@ -260,21 +260,28 @@ def _get_stream_response( :returns: The extracted response with the content of all streaming chunks. 
""" replies = [] + + content: Union[str, Dict[Any, Any]] = "" for chunk in stream: metadata = chunk.to_dict() - streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict()) + for candidate in chunk.candidates: + for part in candidate.content.parts: + + if part._raw_part.text: + content = chunk.text + replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata)) + elif part.function_call: + metadata["function_call"] = part.function_call + content = dict(part.function_call.args.items()) + replies.append( + ChatMessage( + content=content, + role=ChatRole.ASSISTANT, + name=part.function_call.name, + meta=metadata, + ) + ) + streaming_chunk = StreamingChunk(content=content, meta=chunk.to_dict()) streaming_callback(streaming_chunk) - if chunk.text != "": - replies.append(ChatMessage(chunk.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)) - elif chunk.function_call is not None: - metadata["function_call"] = chunk.function_call - replies.append( - ChatMessage( - content=dict(chunk.function_call.args.items()), - role=ChatRole.ASSISTANT, - name=chunk.function_call.name, - meta=metadata, - ) - ) return replies diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index e4e6fa487..ab21008fb 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -261,17 +261,17 @@ def test_run(mock_generative_model): def test_run_with_streaming_callback(mock_generative_model): mock_model = Mock() mock_responses = iter( - [MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text=" Second part")] + [MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text="Second part")] ) - mock_model.send_message.return_value = mock_responses mock_model.start_chat.return_value = mock_model mock_generative_model.return_value = mock_model streaming_callback_called = [] - def streaming_callback(chunk: StreamingChunk) -> None: - streaming_callback_called.append(chunk.content) + def streaming_callback(_chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback) messages = [ @@ -279,12 +279,8 @@ def streaming_callback(chunk: StreamingChunk) -> None: ChatMessage.from_user("What's the capital of France?"), ] response = gemini.run(messages=messages) - mock_model.send_message.assert_called_once() - assert streaming_callback_called == ["First part", " Second part"] assert "replies" in response - assert len(response["replies"]) > 0 - assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) def test_serialization_deserialization_pipeline(): From 4ee566df8c89e0cb2b4df3c2bce3dc67c3d4b7e8 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 20 Sep 2024 16:20:21 +0200 Subject: [PATCH 07/10] Fix errors --- .../generators/google_ai/chat/gemini.py | 34 +++++++++++-------- .../tests/generators/chat/test_chat_gemini.py | 16 +++++++-- .../generators/google_vertex/chat/gemini.py | 15 ++++---- 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index dbc88f514..c4074ce66 100644 --- 
+++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
@@ -311,21 +311,25 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess
         :param response_body: The response from Google AI request.
         :returns: The extracted responses.
         """
-        replies = []
+        replies: List[ChatMessage] = []
         metadata = response_body.to_dict()
-        [candidate.pop("content", None) for candidate in metadata["candidates"]]
-        for candidate in response_body.candidates:
+        for idx, candidate in enumerate(response_body.candidates):
+            candidate_metadata = metadata["candidates"][idx]
+            candidate_metadata.pop("content", None)  # we remove content from the metadata
+
             for part in candidate.content.parts:
                 if part.text != "":
-                    replies.append(ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata))
-                elif part.function_call is not None:
-                    metadata["function_call"] = part.function_call
+                    replies.append(
+                        ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=candidate_metadata)
+                    )
+                elif part.function_call:
+                    candidate_metadata["function_call"] = part.function_call
                     replies.append(
                         ChatMessage(
                             content=dict(part.function_call.args.items()),
                             role=ChatRole.ASSISTANT,
                             name=part.function_call.name,
-                            meta=metadata,
+                            meta=candidate_metadata,
                         )
                     )
         return replies
@@ -340,26 +344,26 @@ def _get_stream_response(
         :param streaming_callback: The handler for the streaming response.
         :returns: The extracted response with the content of all streaming chunks.
         """
-        replies: Union[List[str], List[ChatMessage]] = []
-
+        replies: List[ChatMessage] = []
+        content: Union[str, Dict[str, Any]] = ""
         for chunk in stream:
-            metadata = chunk.to_dict()
+            metadata = chunk.to_dict()  # we store whole chunk as metadata in streaming calls
             for candidate in chunk.candidates:
                 for part in candidate.content.parts:
                     if part.text != "":
-                        replies.append(
-                            ChatMessage(content=part.text, role=ChatRole.ASSISTANT, meta=metadata, name=None)
-                        )
+                        content = part.text
+                        replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None))
                     elif part.function_call is not None:
                         metadata["function_call"] = part.function_call
+                        content = dict(part.function_call.args.items())
                         replies.append(
                             ChatMessage(
-                                content=dict(part.function_call.args.items()),
+                                content=content,
                                 role=ChatRole.ASSISTANT,
                                 name=part.function_call.name,
                                 meta=metadata,
                             )
                         )

-                    streaming_callback(StreamingChunk(content=part.text, meta=chunk.to_dict()))
+                streaming_callback(StreamingChunk(content=content, meta=metadata))
         return replies
diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
index 6177ed359..b6dd40896 100644
--- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
+++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
@@ -220,6 +220,7 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     assert len(response["replies"]) > 0
     assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

+    # check the first response is a function call
     chat_message = response["replies"][0]
     assert "function_call" in chat_message.meta
     assert chat_message.content == {"location": "Berlin", "unit": "celsius"}
@@ -231,9 +232,10 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     assert len(response["replies"]) > 0
     assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

+    # check the second response is not a function call
     chat_message = response["replies"][0]
     assert "function_call" not in chat_message.meta
-    assert chat_message.content
+    assert isinstance(chat_message.content, str)


 @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
@@ -250,7 +252,7 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     get_current_weather_func = FunctionDeclaration.from_function(
         get_current_weather,
         descriptions={
-            "location": "The city and state, e.g. San Francisco, CA",
+            "location": "The city and state, e.g. San Francisco",
             "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
         },
     )
@@ -264,6 +266,11 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
     assert streaming_callback_called

+    # check the first response is a function call
+    chat_message = response["replies"][0]
+    assert "function_call" in chat_message.meta
+    assert chat_message.content == {"location": "Berlin", "unit": "celsius"}
+
     weather = get_current_weather(**response["replies"][0].content)
     messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
     response = gemini_chat.run(messages=messages)
@@ -271,6 +278,11 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     assert len(response["replies"]) > 0
     assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

+    # check the second response is not a function call
+    chat_message = response["replies"][0]
+    assert "function_call" not in chat_message.meta
+    assert isinstance(chat_message.content, str)
+

 @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
 def test_past_conversation():
diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
index 3673ad4c7..e8de0b006 100644
--- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
+++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
@@ -229,10 +229,12 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
         :param response_body: The response from Vertex AI request.
         :returns: The extracted responses.
         """
-        replies = []
+        replies: List[ChatMessage] = []
         for candidate in response_body.candidates:
-            metadata = candidate.to_dict()
             for part in candidate.content.parts:
+                # Remove content from metadata
+                metadata = part.to_dict()
+                metadata.pop("content", None)
                 if part._raw_part.text != "":
                     replies.append(
                         ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)
@@ -259,14 +261,13 @@ def _get_stream_response(
         :param streaming_callback: The handler for the streaming response.
         :returns: The extracted response with the content of all streaming chunks.
""" - replies = [] + replies: List[ChatMessage] = [] + content: Union[str, Dict[str, Any]] = "" - content: Union[str, Dict[Any, Any]] = "" for chunk in stream: - metadata = chunk.to_dict() + metadata = chunk.to_dict() # we store whole chunk as metadata for streaming for candidate in chunk.candidates: for part in candidate.content.parts: - if part._raw_part.text: content = chunk.text replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata)) @@ -281,7 +282,7 @@ def _get_stream_response( meta=metadata, ) ) - streaming_chunk = StreamingChunk(content=content, meta=chunk.to_dict()) + streaming_chunk = StreamingChunk(content=content, meta=metadata) streaming_callback(streaming_chunk) return replies From b755b312b994a48ab25f84f5154c5aa3471b9109 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Mon, 23 Sep 2024 22:18:09 +0200 Subject: [PATCH 08/10] Fix for metadata --- .../components/generators/google_vertex/chat/gemini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index e8de0b006..5fecfbde6 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -231,9 +231,9 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: """ replies: List[ChatMessage] = [] for candidate in response_body.candidates: + metadata = candidate.to_dict() for part in candidate.content.parts: # Remove content from metadata - metadata = part.to_dict() metadata.pop("content", None) if part._raw_part.text != "": replies.append( From 7e4debacfbf02c2f3570d203731d169d20e3def7 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 24 Sep 2024 16:34:46 +0200 Subject: [PATCH 09/10] Updates based on review --- .../components/generators/google_ai/chat/gemini.py | 4 ++-- .../google_ai/tests/generators/chat/test_chat_gemini.py | 4 ++-- .../components/generators/google_vertex/chat/gemini.py | 5 ++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index c4074ce66..56c84968b 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -345,8 +345,8 @@ def _get_stream_response( :returns: The extracted response with the content of all streaming chunks. 
""" replies: List[ChatMessage] = [] - content: Union[str, Dict[str, Any]] = "" for chunk in stream: + content: Union[str, Dict[str, Any]] = "" metadata = chunk.to_dict() # we store whole chunk as metadata in streaming calls for candidate in chunk.candidates: for part in candidate.content.parts: @@ -365,5 +365,5 @@ def _get_stream_response( ) ) - streaming_callback(StreamingChunk(content=content, meta=metadata)) + streaming_callback(StreamingChunk(content=content, meta=metadata)) return replies diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index b6dd40896..c4372db0d 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -207,7 +207,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 get_current_weather_func = FunctionDeclaration.from_function( get_current_weather, descriptions={ - "location": "The city and state, e.g. San Francisco", + "location": "The city, e.g. San Francisco", "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit", }, ) @@ -252,7 +252,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 get_current_weather_func = FunctionDeclaration.from_function( get_current_weather, descriptions={ - "location": "The city and state, e.g. San Francisco", + "location": "The city, e.g. San Francisco", "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit", }, ) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index 5fecfbde6..fb5fd6f4c 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -262,9 +262,9 @@ def _get_stream_response( :returns: The extracted response with the content of all streaming chunks. 
""" replies: List[ChatMessage] = [] - content: Union[str, Dict[str, Any]] = "" for chunk in stream: + content: Union[str, Dict[str, Any]] = "" metadata = chunk.to_dict() # we store whole chunk as metadata for streaming for candidate in chunk.candidates: for part in candidate.content.parts: @@ -282,7 +282,6 @@ def _get_stream_response( meta=metadata, ) ) - streaming_chunk = StreamingChunk(content=content, meta=metadata) - streaming_callback(streaming_chunk) + streaming_callback(StreamingChunk(content=content, meta=metadata)) return replies From 09380d3a4e03b1a4fab6e25fba2b1e021c1bb049 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 24 Sep 2024 16:36:58 +0200 Subject: [PATCH 10/10] Small fix --- .../components/generators/google_vertex/chat/gemini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index fb5fd6f4c..ac4c93228 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -282,6 +282,6 @@ def _get_stream_response( meta=metadata, ) ) - streaming_callback(StreamingChunk(content=content, meta=metadata)) + streaming_callback(StreamingChunk(content=content, meta=metadata)) return replies