feat: Add chatrole tests and meta for GeminiChatGenerators #1090

Merged 11 commits on Sep 24, 2024
@@ -312,16 +312,20 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMessage]:
:returns: The extracted responses.
"""
replies = []
metadata = response_body.to_dict()
[candidate.pop("content", None) for candidate in metadata["candidates"]]
for candidate in response_body.candidates:
for part in candidate.content.parts:
if part.text != "":
replies.append(ChatMessage.from_assistant(part.text))
replies.append(ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata))
elif part.function_call is not None:
metadata["function_call"] = part.function_call
replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata,
)
)
return replies
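
With this hunk, each reply carries the full response metadata in ChatMessage.meta, and a function call is signalled by a "function_call" key. A minimal sketch of how a caller might branch on that metadata (sketch only; my_tools is a hypothetical registry of callables):

    # Sketch: gemini_chat is a configured GoogleAIGeminiChatGenerator,
    # my_tools a hypothetical dict mapping function names to callables.
    from haystack.dataclasses.chat_message import ChatMessage

    response = gemini_chat.run(messages=[ChatMessage.from_user("What's the weather in Berlin?")])
    for reply in response["replies"]:
        if "function_call" in reply.meta:
            result = my_tools[reply.name](**reply.content)  # content is the parsed argument dict
        else:
            print(reply.content)  # plain assistant text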
@@ -336,11 +340,26 @@ def _get_stream_response(
:param streaming_callback: The handler for the streaming response.
:returns: The extracted response with the content of all streaming chunks.
"""
responses = []
replies: Union[List[str], List[ChatMessage]] = []

for chunk in stream:
content = chunk.text if len(chunk.parts) > 0 and "text" in chunk.parts[0] else ""
streaming_callback(StreamingChunk(content=content, meta=chunk.to_dict()))
responses.append(content)
metadata = chunk.to_dict()
for candidate in chunk.candidates:
for part in candidate.content.parts:
if part.text != "":
replies.append(
ChatMessage(content=part.text, role=ChatRole.ASSISTANT, meta=metadata, name=None)
)
elif part.function_call is not None:
metadata["function_call"] = part.function_call
replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata,
)
)

combined_response = "".join(responses).lstrip()
return [ChatMessage.from_assistant(content=combined_response)]
streaming_callback(StreamingChunk(content=part.text, meta=chunk.to_dict()))
return replies
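
The streaming path now does two things per chunk: it forwards the chunk to the streaming callback and accumulates complete ChatMessage replies, so callers receive live tokens and still get the full reply list from run(). A sketch of a collecting callback, assuming text-only chunks:

    # Sketch: collect streamed text while still receiving full replies from run().
    from haystack.dataclasses import StreamingChunk
    from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator

    collected = []

    def collect_chunk(chunk: StreamingChunk) -> None:
        collected.append(chunk.content)

    gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", streaming_callback=collect_chunk)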
integrations/google_ai/tests/generators/chat/test_chat_gemini.py (46 changes: 34 additions & 12 deletions)
@@ -5,7 +5,7 @@
from google.generativeai import GenerationConfig, GenerativeModel
from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool
from haystack.dataclasses import StreamingChunk
from haystack.dataclasses.chat_message import ChatMessage
from haystack.dataclasses.chat_message import ChatMessage, ChatRole

from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator

@@ -207,22 +207,33 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
get_current_weather_func = FunctionDeclaration.from_function(
get_current_weather,
descriptions={
"location": "The city and state, e.g. San Francisco, CA",
"location": "The city and state, e.g. San Francisco",
"unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
},
)

tool = Tool(function_declarations=[get_current_weather_func])
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool])
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
res = gemini_chat.run(messages=messages)
assert len(res["replies"]) > 0
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

chat_message = response["replies"][0]
assert "function_call" in chat_message.meta
assert chat_message.content == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**res["replies"][0].content)
messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
weather = get_current_weather(**chat_message.content)
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

res = gemini_chat.run(messages=messages)
assert len(res["replies"]) > 0
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert chat_message.content


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
@@ -247,10 +258,19 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
tool = Tool(function_declarations=[get_current_weather_func])
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback)
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
res = gemini_chat.run(messages=messages)
assert len(res["replies"]) > 0
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
assert streaming_callback_called

weather = get_current_weather(**response["replies"][0].content)
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
def test_past_conversation():
@@ -261,5 +281,7 @@ def test_past_conversation():
ChatMessage.from_assistant(content="It's an arithmetic operation."),
ChatMessage.from_user(content="Yeah, but what's the result?"),
]
res = gemini_chat.run(messages=messages)
assert len(res["replies"]) > 0
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
@@ -231,15 +231,20 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
"""
replies = []
for candidate in response_body.candidates:
metadata = candidate.to_dict()
for part in candidate.content.parts:
if part._raw_part.text != "":
replies.append(ChatMessage.from_assistant(part.text))
elif part.function_call is not None:
replies.append(
ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)
)
elif part.function_call:
metadata["function_call"] = part.function_call
replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata,
)
)
return replies
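
Unlike the Google AI generator above, the Vertex variant takes metadata from each candidate (candidate.to_dict()) rather than from the whole response body. An illustrative shape of a function-call reply after this change (values are made up, and the exact meta keys depend on the Vertex SDK version):

    # Illustrative only; meta keys vary with the SDK.
    ChatMessage(
        content={"location": "Berlin", "unit": "celsius"},
        role=ChatRole.ASSISTANT,
        name="get_current_weather",
        meta={"finish_reason": "STOP", "function_call": part.function_call},  # candidate.to_dict() plus the raw call
    )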
@@ -254,11 +259,29 @@ def _get_stream_response(
:param streaming_callback: The handler for the streaming response.
:returns: The extracted response with the content of all streaming chunks.
"""
responses = []
replies = []

content: Union[str, Dict[Any, Any]] = ""
for chunk in stream:
streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict())
metadata = chunk.to_dict()
for candidate in chunk.candidates:
for part in candidate.content.parts:

if part._raw_part.text:
content = chunk.text
replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata))
elif part.function_call:
metadata["function_call"] = part.function_call
content = dict(part.function_call.args.items())
replies.append(
ChatMessage(
content=content,
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata,
)
)
streaming_chunk = StreamingChunk(content=content, meta=chunk.to_dict())
streaming_callback(streaming_chunk)
responses.append(streaming_chunk.content)

combined_response = "".join(responses).lstrip()
return [ChatMessage.from_assistant(content=combined_response)]
return replies
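
Note the behavioral change: the streaming path previously returned a single assistant message built by joining the chunk texts, and it now returns one ChatMessage per part. A caller that still wants the old combined string can merge the text replies itself, for example:

    # Sketch: reproduce the old combined-string behavior from the new per-part replies.
    text = "".join(r.content for r in response["replies"] if isinstance(r.content, str)).lstrip()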
integrations/google_vertex/tests/chat/test_gemini.py (18 changes: 10 additions & 8 deletions)
@@ -3,7 +3,7 @@
import pytest
from haystack import Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
from vertexai.generative_models import (
Content,
FunctionDeclaration,
@@ -249,9 +249,12 @@ def test_run(mock_generative_model):
ChatMessage.from_user("What's the capital of France?"),
]
gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None)
gemini.run(messages=messages)
response = gemini.run(messages=messages)

mock_model.send_message.assert_called_once()
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
@@ -260,25 +263,24 @@ def test_run_with_streaming_callback(mock_generative_model):
mock_responses = iter(
[MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text="Second part")]
)

mock_model.send_message.return_value = mock_responses
mock_model.start_chat.return_value = mock_model
mock_generative_model.return_value = mock_model

streaming_callback_called = []

def streaming_callback(chunk: StreamingChunk) -> None:
streaming_callback_called.append(chunk.content)
def streaming_callback(_chunk: StreamingChunk) -> None:
nonlocal streaming_callback_called
streaming_callback_called = True

gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback)
messages = [
ChatMessage.from_system("You are a helpful assistant"),
ChatMessage.from_user("What's the capital of France?"),
]
gemini.run(messages=messages)

response = gemini.run(messages=messages)
mock_model.send_message.assert_called_once()
assert streaming_callback_called == ["First part", "Second part"]
assert "replies" in response


def test_serialization_deserialization_pipeline():
Expand Down