Skip to content

Commit

Permalink
fix Gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Dec 10, 2024
1 parent 7a1297b commit 324a7fa
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
from typing import Any, Callable, Dict, List, Optional, Union

Expand Down Expand Up @@ -36,12 +37,12 @@ class GoogleAIGeminiChatGenerator:
messages = [ChatMessage.from_user("What is the most interesting thing you know?")]
res = gemini_chat.run(messages=messages)
for reply in res["replies"]:
print(reply.content)
print(reply.text)
messages += res["replies"] + [ChatMessage.from_user("Tell me more about it")]
res = gemini_chat.run(messages=messages)
for reply in res["replies"]:
print(reply.content)
print(reply.text)
```
Expand Down Expand Up @@ -85,14 +86,14 @@ def get_current_weather(location: str, unit: str = "celsius") -> str:
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", api_key=Secret.from_token("<MY_API_KEY>"),
tools=[tool])
messages = [ChatMessage.from_user(content = "What is the temperature in celsius in Berlin?")]
messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")]
res = gemini_chat.run(messages=messages)
weather = get_current_weather(**res["replies"][0].content)
weather = get_current_weather(**json.loads(res["replies"][0].text))
messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
res = gemini_chat.run(messages=messages)
for reply in res["replies"]:
print(reply.content)
print(reply.text)
```
"""

Expand Down Expand Up @@ -234,37 +235,37 @@ def _message_to_part(self, message: ChatMessage) -> Part:
p = Part()
p.function_call.name = message.name
p.function_call.args = {}
for k, v in message.content.items():
for k, v in json.loads(message.text).items():
p.function_call.args[k] = v
return p
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
p = Part()
p.text = message.content
p.text = message.text
return p
elif message.role == ChatRole.FUNCTION:
p = Part()
p.function_response.name = message.name
p.function_response.response = message.content
p.function_response.response = message.text
return p
elif message.role == ChatRole.USER:
return self._convert_part(message.content)
return self._convert_part(message.text)

def _message_to_content(self, message: ChatMessage) -> Content:
if message.role == ChatRole.ASSISTANT and message.name:
part = Part()
part.function_call.name = message.name
part.function_call.args = {}
for k, v in message.content.items():
for k, v in json.loads(message.text).items():
part.function_call.args[k] = v
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
part = Part()
part.text = message.content
part.text = message.text
elif message.role == ChatRole.FUNCTION:
part = Part()
part.function_response.name = message.name
part.function_response.response = message.content
part.function_response.response = message.text
elif message.role == ChatRole.USER:
part = self._convert_part(message.content)
part = self._convert_part(message.text)
else:
msg = f"Unsupported message role {message.role}"
raise ValueError(msg)
Expand Down Expand Up @@ -338,7 +339,7 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess
elif part.function_call:
candidate_metadata["function_call"] = part.function_call
new_message = ChatMessage.from_assistant(
content=dict(part.function_call.args.items()), meta=candidate_metadata
content=json.dumps(dict(part.function_call.args)), meta=candidate_metadata
)
new_message.name = part.function_call.name
replies.append(new_message)
Expand Down Expand Up @@ -366,7 +367,7 @@ def _get_stream_response(
replies.append(ChatMessage.from_assistant(content=content, meta=metadata))
elif "function_call" in part and len(part["function_call"]) > 0:
metadata["function_call"] = part["function_call"]
content = part["function_call"]["args"]
content = json.dumps(dict(part["function_call"]["args"]))
new_message = ChatMessage.from_assistant(content=content, meta=metadata)
new_message.name = part["function_call"]["name"]
replies.append(new_message)
Expand Down
13 changes: 7 additions & 6 deletions integrations/google_ai/tests/generators/chat/test_chat_gemini.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
from unittest.mock import patch

Expand Down Expand Up @@ -223,9 +224,9 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
# check the first response is a function call
chat_message = response["replies"][0]
assert "function_call" in chat_message.meta
assert chat_message.content == {"location": "Berlin", "unit": "celsius"}
assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**chat_message.content)
weather = get_current_weather(**json.loads(chat_message.text))
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
Expand All @@ -235,7 +236,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
# check the second response is not a function call
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert isinstance(chat_message.content, str)
assert isinstance(chat_message.text, str)


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
Expand Down Expand Up @@ -269,9 +270,9 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
# check the first response is a function call
chat_message = response["replies"][0]
assert "function_call" in chat_message.meta
assert chat_message.content == {"location": "Berlin", "unit": "celsius"}
assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**response["replies"][0].content)
weather = get_current_weather(**json.loads(response["replies"][0].text))
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
Expand All @@ -281,7 +282,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
# check the second response is not a function call
chat_message = response["replies"][0]
assert "function_call" not in chat_message.meta
assert isinstance(chat_message.content, str)
assert isinstance(chat_message.text, str)


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
Expand Down

0 comments on commit 324a7fa

Please sign in to comment.