Skip to content

Commit

Permalink
Fix errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Sep 20, 2024
1 parent 1f53cd9 commit 4ee566d
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -311,21 +311,25 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess
:param response_body: The response from Google AI request.
:returns: The extracted responses.
"""
replies = []
replies: List[ChatMessage] = []
metadata = response_body.to_dict()
[candidate.pop("content", None) for candidate in metadata["candidates"]]
for candidate in response_body.candidates:
for idx, candidate in enumerate(response_body.candidates):
candidate_metadata = metadata["candidates"][idx]
candidate_metadata.pop("content", None) # we remove content from the metadata

for part in candidate.content.parts:
if part.text != "":
replies.append(ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata))
elif part.function_call is not None:
metadata["function_call"] = part.function_call
replies.append(
ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=candidate_metadata)
)
elif part.function_call:
candidate_metadata["function_call"] = part.function_call
replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata,
meta=candidate_metadata,
)
)
return replies
Expand All @@ -340,26 +344,26 @@ def _get_stream_response(
:param streaming_callback: The handler for the streaming response.
:returns: The extracted response with the content of all streaming chunks.
"""
replies: Union[List[str], List[ChatMessage]] = []

replies: List[ChatMessage] = []
content: Union[str, Dict[str, Any]] = ""
for chunk in stream:
metadata = chunk.to_dict()
metadata = chunk.to_dict() # we store whole chunk as metadata in streaming calls
for candidate in chunk.candidates:
for part in candidate.content.parts:
if part.text != "":
replies.append(
ChatMessage(content=part.text, role=ChatRole.ASSISTANT, meta=metadata, name=None)
)
content = part.text
replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None))
elif part.function_call is not None:
metadata["function_call"] = part.function_call
content = dict(part.function_call.args.items())
replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
content=content,
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata,
)
)

streaming_callback(StreamingChunk(content=part.text, meta=chunk.to_dict()))
streaming_callback(StreamingChunk(content=content, meta=metadata))
return replies
16 changes: 14 additions & 2 deletions integrations/google_ai/tests/generators/chat/test_chat_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

# check the first response is a function call
chat_message = response["replies"][0]
assert "function_call" in chat_message.meta
assert chat_message.content == {"location": "Berlin", "unit": "celsius"}
Expand All @@ -231,9 +232,10 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

# check the second response is not a function call
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert chat_message.content
assert isinstance(chat_message.content, str)


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
Expand All @@ -250,7 +252,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
get_current_weather_func = FunctionDeclaration.from_function(
get_current_weather,
descriptions={
"location": "The city and state, e.g. San Francisco, CA",
"location": "The city and state, e.g. San Francisco",
"unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
},
)
Expand All @@ -264,13 +266,23 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
assert streaming_callback_called

# check the first response is a function call
chat_message = response["replies"][0]
assert "function_call" in chat_message.meta
assert chat_message.content == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**response["replies"][0].content)
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

# check the second response is not a function call
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert isinstance(chat_message.content, str)


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
def test_past_conversation():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,12 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
:param response_body: The response from Vertex AI request.
:returns: The extracted responses.
"""
replies = []
replies: List[ChatMessage] = []
for candidate in response_body.candidates:
metadata = candidate.to_dict()
for part in candidate.content.parts:
# Remove content from metadata
metadata = part.to_dict()
metadata.pop("content", None)
if part._raw_part.text != "":
replies.append(
ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)
Expand All @@ -259,14 +261,13 @@ def _get_stream_response(
:param streaming_callback: The handler for the streaming response.
:returns: The extracted response with the content of all streaming chunks.
"""
replies = []
replies: List[ChatMessage] = []
content: Union[str, Dict[str, Any]] = ""

content: Union[str, Dict[Any, Any]] = ""
for chunk in stream:
metadata = chunk.to_dict()
metadata = chunk.to_dict() # we store whole chunk as metadata for streaming
for candidate in chunk.candidates:
for part in candidate.content.parts:

if part._raw_part.text:
content = chunk.text
replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata))
Expand All @@ -281,7 +282,7 @@ def _get_stream_response(
meta=metadata,
)
)
streaming_chunk = StreamingChunk(content=content, meta=chunk.to_dict())
streaming_chunk = StreamingChunk(content=content, meta=metadata)
streaming_callback(streaming_chunk)

return replies

0 comments on commit 4ee566d

Please sign in to comment.