diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
index dbc88f514..c4074ce66 100644
--- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
+++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
@@ -311,21 +311,25 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess
         :param response_body: The response from Google AI request.
         :returns: The extracted responses.
         """
-        replies = []
+        replies: List[ChatMessage] = []
         metadata = response_body.to_dict()
-        [candidate.pop("content", None) for candidate in metadata["candidates"]]
-        for candidate in response_body.candidates:
+        for idx, candidate in enumerate(response_body.candidates):
+            candidate_metadata = metadata["candidates"][idx]
+            candidate_metadata.pop("content", None)  # we remove content from the metadata
+
             for part in candidate.content.parts:
                 if part.text != "":
-                    replies.append(ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata))
-                elif part.function_call is not None:
-                    metadata["function_call"] = part.function_call
+                    replies.append(
+                        ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=candidate_metadata)
+                    )
+                elif part.function_call:
+                    candidate_metadata["function_call"] = part.function_call
                     replies.append(
                         ChatMessage(
                             content=dict(part.function_call.args.items()),
                             role=ChatRole.ASSISTANT,
                             name=part.function_call.name,
-                            meta=metadata,
+                            meta=candidate_metadata,
                         )
                     )
         return replies
@@ -340,26 +344,26 @@ def _get_stream_response(
         :param streaming_callback: The handler for the streaming response.
         :returns: The extracted response with the content of all streaming chunks.
         """
-        replies: Union[List[str], List[ChatMessage]] = []
-
+        replies: List[ChatMessage] = []
+        content: Union[str, Dict[str, Any]] = ""
         for chunk in stream:
-            metadata = chunk.to_dict()
+            metadata = chunk.to_dict()  # we store whole chunk as metadata in streaming calls
             for candidate in chunk.candidates:
                 for part in candidate.content.parts:
                     if part.text != "":
-                        replies.append(
-                            ChatMessage(content=part.text, role=ChatRole.ASSISTANT, meta=metadata, name=None)
-                        )
+                        content = part.text
+                        replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None))
                     elif part.function_call is not None:
                         metadata["function_call"] = part.function_call
+                        content = dict(part.function_call.args.items())
                         replies.append(
                             ChatMessage(
-                                content=dict(part.function_call.args.items()),
+                                content=content,
                                 role=ChatRole.ASSISTANT,
                                 name=part.function_call.name,
                                 meta=metadata,
                             )
                         )
-            streaming_callback(StreamingChunk(content=part.text, meta=chunk.to_dict()))
+            streaming_callback(StreamingChunk(content=content, meta=metadata))
         return replies
diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
index 6177ed359..b6dd40896 100644
--- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
+++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
@@ -220,6 +220,7 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     assert len(response["replies"]) > 0
     assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
 
+    # check the first response is a function call
     chat_message = response["replies"][0]
     assert "function_call" in chat_message.meta
     assert chat_message.content == {"location": "Berlin", "unit": "celsius"}
@@ -231,9 +232,10 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     assert len(response["replies"]) > 0
     assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
 
+    # check the second response is not a function call
     chat_message = response["replies"][0]
     assert "function_call" not in chat_message.meta
-    assert chat_message.content
+    assert isinstance(chat_message.content, str)
 
 
 @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
@@ -250,7 +252,7 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     get_current_weather_func = FunctionDeclaration.from_function(
         get_current_weather,
         descriptions={
-            "location": "The city and state, e.g. San Francisco, CA",
+            "location": "The city and state, e.g. San Francisco",
             "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
         },
     )
@@ -264,6 +266,11 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
     assert streaming_callback_called
 
+    # check the first response is a function call
+    chat_message = response["replies"][0]
+    assert "function_call" in chat_message.meta
+    assert chat_message.content == {"location": "Berlin", "unit": "celsius"}
+
     weather = get_current_weather(**response["replies"][0].content)
     messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
     response = gemini_chat.run(messages=messages)
@@ -271,6 +278,11 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     assert len(response["replies"]) > 0
     assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
 
+    # check the second response is not a function call
+    chat_message = response["replies"][0]
+    assert "function_call" not in chat_message.meta
+    assert isinstance(chat_message.content, str)
+
 
 @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
 def test_past_conversation():
diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
index 3673ad4c7..e8de0b006 100644
--- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
+++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
@@ -229,10 +229,12 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
         :param response_body: The response from Vertex AI request.
         :returns: The extracted responses.
         """
-        replies = []
+        replies: List[ChatMessage] = []
         for candidate in response_body.candidates:
-            metadata = candidate.to_dict()
             for part in candidate.content.parts:
+                # Remove content from metadata
+                metadata = part.to_dict()
+                metadata.pop("content", None)
                 if part._raw_part.text != "":
                     replies.append(
                         ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)
                     )
@@ -259,14 +261,13 @@ def _get_stream_response(
         :param streaming_callback: The handler for the streaming response.
         :returns: The extracted response with the content of all streaming chunks.
         """
-        replies = []
+        replies: List[ChatMessage] = []
+        content: Union[str, Dict[str, Any]] = ""
 
-        content: Union[str, Dict[Any, Any]] = ""
         for chunk in stream:
-            metadata = chunk.to_dict()
+            metadata = chunk.to_dict()  # we store whole chunk as metadata for streaming
             for candidate in chunk.candidates:
                 for part in candidate.content.parts:
-
                     if part._raw_part.text:
                         content = chunk.text
                         replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata))
@@ -281,7 +282,7 @@
                                 meta=metadata,
                             )
                         )
-            streaming_chunk = StreamingChunk(content=content, meta=chunk.to_dict())
+            streaming_chunk = StreamingChunk(content=content, meta=metadata)
             streaming_callback(streaming_chunk)
         return replies
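
Note (not part of the patch): a minimal, self-contained sketch of the behavior the Google AI _get_response change introduces, namely that each reply now carries only its own candidate's metadata with the raw "content" field stripped, and function-call replies additionally record the call under "function_call". The Reply dataclass, the extract_replies helper, and the dictionary-shaped fake response below are hypothetical stand-ins; the real component operates on a google.generativeai GenerateContentResponse and returns Haystack ChatMessage objects.

# Hypothetical stand-ins for illustration only; not the component's actual API.
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Reply:
    content: Any
    meta: Dict[str, Any] = field(default_factory=dict)


def extract_replies(response: Dict[str, Any]) -> List[Reply]:
    """Attach per-candidate metadata (minus 'content') to each reply, mirroring the patched logic."""
    replies: List[Reply] = []
    for candidate in response["candidates"]:
        # copy this candidate's metadata and drop the raw content
        candidate_metadata = {k: v for k, v in candidate.items() if k != "content"}
        for part in candidate["content"]["parts"]:
            if part.get("text"):
                replies.append(Reply(content=part["text"], meta=candidate_metadata))
            elif part.get("function_call"):
                candidate_metadata["function_call"] = part["function_call"]
                replies.append(Reply(content=part["function_call"]["args"], meta=candidate_metadata))
    return replies


if __name__ == "__main__":
    fake_response = {
        "candidates": [
            {"content": {"parts": [{"text": "Hello!"}]}, "finish_reason": "STOP", "index": 0},
            {
                "content": {"parts": [{"function_call": {"name": "get_current_weather", "args": {"location": "Berlin"}}}]},
                "finish_reason": "STOP",
                "index": 1,
            },
        ]
    }
    for reply in extract_replies(fake_response):
        print(reply.content, reply.meta)  # meta keeps finish_reason/index, never the raw content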