Skip to content

Commit

Permalink
fix: make GoogleAI Chat Generator compatible with new ChatMessage; …
Browse files Browse the repository at this point in the history
…small fixes to Cohere tests (#1253)

* draft

* improvements

* small improvemtn

* rm duplication

* simplification
  • Loading branch information
anakin87 authored Dec 19, 2024
1 parent e35d3cb commit 58cb135
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 17 deletions.
2 changes: 1 addition & 1 deletion integrations/cohere/tests/test_cohere_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def streaming_chunk(text: str):

@pytest.fixture
def chat_messages():
return [ChatMessage.from_assistant(content="What's the capital of France")]
return [ChatMessage.from_assistant("What's the capital of France")]


class TestCohereChatGenerator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@ def _message_to_part(self, message: ChatMessage) -> Part:
p.function_response.name = message.name
p.function_response.response = message.text
return p
elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL):
p = Part()
p.function_response.name = message.tool_call_result.origin.tool_name
p.function_response.response = message.tool_call_result.result
return p
elif message.is_from(ChatRole.USER):
return self._convert_part(message.text)

Expand All @@ -266,10 +271,17 @@ def _message_to_content(self, message: ChatMessage) -> Content:
part.function_response.response = message.text
elif message.is_from(ChatRole.USER):
part = self._convert_part(message.text)
elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL):
part = Part()
part.function_response.name = message.tool_call_result.origin.tool_name
part.function_response.response = message.tool_call_result.result
else:
msg = f"Unsupported message role {message.role}"
raise ValueError(msg)
role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model"

role = "user"
if message.is_from(ChatRole.ASSISTANT) or message.is_from(ChatRole.SYSTEM):
role = "model"
return Content(parts=[part], role=role)

@component.output_types(replies=List[ChatMessage])
Expand Down Expand Up @@ -335,13 +347,16 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess

for part in candidate.content.parts:
if part.text != "":
replies.append(ChatMessage.from_assistant(content=part.text, meta=candidate_metadata))
replies.append(ChatMessage.from_assistant(part.text, meta=candidate_metadata))
elif part.function_call:
candidate_metadata["function_call"] = part.function_call
new_message = ChatMessage.from_assistant(
content=json.dumps(dict(part.function_call.args)), meta=candidate_metadata
json.dumps(dict(part.function_call.args)), meta=candidate_metadata
)
new_message.name = part.function_call.name
try:
new_message.name = part.function_call.name
except AttributeError:
new_message._name = part.function_call.name
replies.append(new_message)
return replies

Expand All @@ -364,12 +379,15 @@ def _get_stream_response(
for part in candidate["content"]["parts"]:
if "text" in part and part["text"] != "":
content = part["text"]
replies.append(ChatMessage.from_assistant(content=content, meta=metadata))
replies.append(ChatMessage.from_assistant(content, meta=metadata))
elif "function_call" in part and len(part["function_call"]) > 0:
metadata["function_call"] = part["function_call"]
content = json.dumps(dict(part["function_call"]["args"]))
new_message = ChatMessage.from_assistant(content=content, meta=metadata)
new_message.name = part["function_call"]["name"]
new_message = ChatMessage.from_assistant(content, meta=metadata)
try:
new_message.name = part["function_call"]["name"]
except AttributeError:
new_message._name = part["function_call"]["name"]
replies.append(new_message)

streaming_callback(StreamingChunk(content=content, meta=metadata))
Expand Down
18 changes: 9 additions & 9 deletions integrations/google_ai/tests/generators/chat/test_chat_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001

tool = Tool(function_declarations=[get_current_weather_func])
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool])
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
Expand All @@ -227,7 +227,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**json.loads(chat_message.text))
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
Expand Down Expand Up @@ -260,7 +260,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001

tool = Tool(function_declarations=[get_current_weather_func])
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback)
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
Expand All @@ -272,8 +272,8 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
assert "function_call" in chat_message.meta
assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**json.loads(response["replies"][0].text))
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
weather = get_current_weather(**json.loads(chat_message.text))
messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
Expand All @@ -289,10 +289,10 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
def test_past_conversation():
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro")
messages = [
ChatMessage.from_system(content="You are a knowledageable mathematician."),
ChatMessage.from_user(content="What is 2+2?"),
ChatMessage.from_assistant(content="It's an arithmetic operation."),
ChatMessage.from_user(content="Yeah, but what's the result?"),
ChatMessage.from_system("You are a knowledageable mathematician."),
ChatMessage.from_user("What is 2+2?"),
ChatMessage.from_assistant("It's an arithmetic operation."),
ChatMessage.from_user("Yeah, but what's the result?"),
]
response = gemini_chat.run(messages=messages)
assert "replies" in response
Expand Down

0 comments on commit 58cb135

Please sign in to comment.