feat: Add chatrole tests and meta for GeminiChatGenerators #1090

Merged (11 commits, Sep 24, 2024)
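
In short: with these changes, every reply returned by the Gemini chat generators carries role=ChatRole.ASSISTANT plus candidate metadata, and a function-call part surfaces its arguments as the message content together with a function_call entry in meta. A rough usage sketch of the behavior exercised by the new tests (the dummy tool body and the printed values are illustrative, not taken from this PR; GOOGLE_API_KEY must be set, as in the integration tests):

from google.generativeai.types import FunctionDeclaration, Tool
from haystack.dataclasses.chat_message import ChatMessage, ChatRole

from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator


def get_current_weather(location: str, unit: str = "celsius"):
    # Illustrative stand-in for a real weather lookup.
    return {"weather": "sunny", "temperature": 21.8, "unit": unit}


get_current_weather_func = FunctionDeclaration.from_function(
    get_current_weather,
    descriptions={
        "location": "The city and state, e.g. San Francisco",
        "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
    },
)
tool = Tool(function_declarations=[get_current_weather_func])

# Reads GOOGLE_API_KEY from the environment, as in the integration tests.
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool])
response = gemini_chat.run(messages=[ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")])

reply = response["replies"][0]
assert reply.role == ChatRole.ASSISTANT      # every reply now carries the assistant role
if "function_call" in reply.meta:            # function calls are flagged in the reply metadata
    print(reply.content)                     # e.g. {"location": "Berlin", "unit": "celsius"}
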
@@ -311,17 +311,25 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMessage]:
:param response_body: The response from the Google AI request.
:returns: The extracted responses.
"""
replies: List[ChatMessage] = []
metadata = response_body.to_dict()
for idx, candidate in enumerate(response_body.candidates):
    candidate_metadata = metadata["candidates"][idx]
    candidate_metadata.pop("content", None)  # we remove content from the metadata

    for part in candidate.content.parts:
        if part.text != "":
            replies.append(
                ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=candidate_metadata)
            )
        elif part.function_call:
            candidate_metadata["function_call"] = part.function_call
            replies.append(
                ChatMessage(
                    content=dict(part.function_call.args.items()),
                    role=ChatRole.ASSISTANT,
                    name=part.function_call.name,
                    meta=candidate_metadata,
                )
            )
return replies
@@ -336,11 +344,26 @@ def _get_stream_response(
:param streaming_callback: The handler for the streaming response.
:returns: The extracted response with the content of all streaming chunks.
"""
replies: List[ChatMessage] = []
content: Union[str, Dict[str, Any]] = ""
for chunk in stream:
    metadata = chunk.to_dict()  # we store whole chunk as metadata in streaming calls
    for candidate in chunk.candidates:
        for part in candidate.content.parts:
            if part.text != "":
                content = part.text
                replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None))
            elif part.function_call is not None:
                metadata["function_call"] = part.function_call
                content = dict(part.function_call.args.items())
                replies.append(
                    ChatMessage(
                        content=content,
                        role=ChatRole.ASSISTANT,
                        name=part.function_call.name,
                        meta=metadata,
                    )
                )
            streaming_callback(StreamingChunk(content=content, meta=metadata))

Contributor:

Is the indentation correct here? We're calling streaming_callback once per candidate. Shouldn't it be called once per part? What if there are multiple parts per candidate?

Contributor Author (@Amnah199, Sep 24, 2024):

You have a valid point. I looked into this, but it's difficult to find any chunk that contains multiple Parts, so for now it produces the same result. In other generators, the entire chunk is passed to the callback function, as shown here. Because of this, I'm not entirely certain, but we can do it your way.
@anakin87, would you have any input here?

Member:

Sorry. @vblagoje knows better.

Contributor Author:

@silvanocerza I updated it based on your suggestion for now, which makes more sense. Also, this PR is blocking my other PR, so I'll just go ahead.
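
To make the per-candidate vs. per-part question concrete, here is a small self-contained sketch (toy stand-in classes, not the PR's code): if a candidate ever carried several parts, a callback placed at the candidate level would fire only once and would only see the content left over from the last part, whereas a part-level callback fires for each part.

from dataclasses import dataclass, field
from typing import List


@dataclass
class Part:
    text: str


@dataclass
class Candidate:
    parts: List[Part] = field(default_factory=list)


# A hypothetical chunk whose single candidate carries two parts.
candidates = [Candidate(parts=[Part("first "), Part("second")])]

per_candidate_calls = 0
per_part_calls = 0
for candidate in candidates:
    for part in candidate.parts:
        content = part.text
        per_part_calls += 1       # callback inside the part loop: fires for "first " and "second"
    per_candidate_calls += 1      # callback inside the candidate loop: fires once, content == "second"

print(per_candidate_calls, per_part_calls)  # 1 2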

return replies
60 changes: 47 additions & 13 deletions integrations/google_ai/tests/generators/chat/test_chat_gemini.py
@@ -5,7 +5,7 @@
from google.generativeai import GenerationConfig, GenerativeModel
from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool
from haystack.dataclasses import StreamingChunk
from haystack.dataclasses.chat_message import ChatMessage, ChatRole

from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator

@@ -207,22 +207,35 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
get_current_weather_func = FunctionDeclaration.from_function(
get_current_weather,
descriptions={
"location": "The city and state, e.g. San Francisco, CA",
"location": "The city and state, e.g. San Francisco",
Amnah199 marked this conversation as resolved.
Show resolved Hide resolved
"unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
},
)

tool = Tool(function_declarations=[get_current_weather_func])
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool])
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

# check the first response is a function call
chat_message = response["replies"][0]
assert "function_call" in chat_message.meta
assert chat_message.content == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**chat_message.content)
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

# check the second response is not a function call
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert isinstance(chat_message.content, str)


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
@@ -239,18 +252,37 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
get_current_weather_func = FunctionDeclaration.from_function(
get_current_weather,
descriptions={
"location": "The city and state, e.g. San Francisco, CA",
"location": "The city and state, e.g. San Francisco",
Amnah199 marked this conversation as resolved.
Show resolved Hide resolved
"unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
},
)

tool = Tool(function_declarations=[get_current_weather_func])
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback)
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
assert streaming_callback_called

# check the first response is a function call
chat_message = response["replies"][0]
assert "function_call" in chat_message.meta
assert chat_message.content == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**response["replies"][0].content)
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

# check the second response is not a function call
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert isinstance(chat_message.content, str)


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
def test_past_conversation():
@@ -261,5 +293,7 @@ def test_past_conversation():
ChatMessage.from_assistant(content="It's an arithmetic operation."),
ChatMessage.from_user(content="Yeah, but what's the result?"),
]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
@@ -229,17 +229,24 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
:param response_body: The response from the Vertex AI request.
:returns: The extracted responses.
"""
replies: List[ChatMessage] = []
for candidate in response_body.candidates:
    for part in candidate.content.parts:
        # Remove content from metadata
        metadata = part.to_dict()
        metadata.pop("content", None)
        if part._raw_part.text != "":
            replies.append(
                ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)
            )
        elif part.function_call:
            metadata["function_call"] = part.function_call
            replies.append(
                ChatMessage(
                    content=dict(part.function_call.args.items()),
                    role=ChatRole.ASSISTANT,
                    name=part.function_call.name,
                    meta=metadata,
                )
            )
return replies
@@ -254,11 +261,28 @@ def _get_stream_response(
:param streaming_callback: The handler for the streaming response.
:returns: The extracted response with the content of all streaming chunks.
"""
replies: List[ChatMessage] = []
content: Union[str, Dict[str, Any]] = ""

for chunk in stream:
    metadata = chunk.to_dict()  # we store whole chunk as metadata for streaming
    for candidate in chunk.candidates:
        for part in candidate.content.parts:
            if part._raw_part.text:
                content = chunk.text
                replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata))
            elif part.function_call:
                metadata["function_call"] = part.function_call
                content = dict(part.function_call.args.items())
                replies.append(
                    ChatMessage(
                        content=content,
                        role=ChatRole.ASSISTANT,
                        name=part.function_call.name,
                        meta=metadata,
                    )
                )
            streaming_chunk = StreamingChunk(content=content, meta=metadata)
            streaming_callback(streaming_chunk)
return replies
18 changes: 10 additions & 8 deletions integrations/google_vertex/tests/chat/test_gemini.py
@@ -3,7 +3,7 @@
import pytest
from haystack import Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
from vertexai.generative_models import (
Content,
FunctionDeclaration,
@@ -249,9 +249,12 @@ def test_run(mock_generative_model):
ChatMessage.from_user("What's the capital of France?"),
]
gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None)
response = gemini.run(messages=messages)

mock_model.send_message.assert_called_once()
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
@@ -260,25 +263,24 @@ def test_run_with_streaming_callback(mock_generative_model):
mock_responses = iter(
[MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text="Second part")]
)

mock_model.send_message.return_value = mock_responses
mock_model.start_chat.return_value = mock_model
mock_generative_model.return_value = mock_model

streaming_callback_called = False

def streaming_callback(_chunk: StreamingChunk) -> None:
    nonlocal streaming_callback_called
    streaming_callback_called = True

gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback)
messages = [
ChatMessage.from_system("You are a helpful assistant"),
ChatMessage.from_user("What's the capital of France?"),
]
response = gemini.run(messages=messages)
mock_model.send_message.assert_called_once()
assert "replies" in response


def test_serialization_deserialization_pipeline():