Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make GoogleAI Chat Generator compatible with new ChatMessage; small fixes to Cohere tests #1253

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/cohere/tests/test_cohere_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def streaming_chunk(text: str):

@pytest.fixture
def chat_messages():
return [ChatMessage.from_assistant(content="What's the capital of France")]
return [ChatMessage.from_assistant("What's the capital of France")]


class TestCohereChatGenerator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@ def _message_to_part(self, message: ChatMessage) -> Part:
p.function_response.name = message.name
p.function_response.response = message.text
return p
elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL):
p = Part()
p.function_response.name = message.tool_call_result.origin.tool_name
p.function_response.response = message.tool_call_result.result
return p
elif message.is_from(ChatRole.USER):
return self._convert_part(message.text)

Expand All @@ -266,10 +271,17 @@ def _message_to_content(self, message: ChatMessage) -> Content:
part.function_response.response = message.text
elif message.is_from(ChatRole.USER):
part = self._convert_part(message.text)
elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Even though I don't particularly like this check) it is done to determine whether the message has the tool role — a role only present in Haystack main (to be released in 2.9.0).

part = Part()
part.function_response.name = message.tool_call_result.origin.tool_name
part.function_response.response = message.tool_call_result.result
else:
msg = f"Unsupported message role {message.role}"
raise ValueError(msg)
role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model"

role = "user"
if message.is_from(ChatRole.ASSISTANT) or message.is_from(ChatRole.SYSTEM):
role = "model"
return Content(parts=[part], role=role)

@component.output_types(replies=List[ChatMessage])
Expand Down Expand Up @@ -335,13 +347,16 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess

for part in candidate.content.parts:
if part.text != "":
replies.append(ChatMessage.from_assistant(content=part.text, meta=candidate_metadata))
replies.append(ChatMessage.from_assistant(part.text, meta=candidate_metadata))
elif part.function_call:
candidate_metadata["function_call"] = part.function_call
new_message = ChatMessage.from_assistant(
content=json.dumps(dict(part.function_call.args)), meta=candidate_metadata
json.dumps(dict(part.function_call.args)), meta=candidate_metadata
)
new_message.name = part.function_call.name
try:
new_message.name = part.function_call.name
except AttributeError:
new_message._name = part.function_call.name
replies.append(new_message)
return replies

Expand All @@ -364,12 +379,15 @@ def _get_stream_response(
for part in candidate["content"]["parts"]:
if "text" in part and part["text"] != "":
content = part["text"]
replies.append(ChatMessage.from_assistant(content=content, meta=metadata))
replies.append(ChatMessage.from_assistant(content, meta=metadata))
elif "function_call" in part and len(part["function_call"]) > 0:
metadata["function_call"] = part["function_call"]
content = json.dumps(dict(part["function_call"]["args"]))
new_message = ChatMessage.from_assistant(content=content, meta=metadata)
new_message.name = part["function_call"]["name"]
new_message = ChatMessage.from_assistant(content, meta=metadata)
try:
new_message.name = part["function_call"]["name"]
except AttributeError:
new_message._name = part["function_call"]["name"]
replies.append(new_message)

streaming_callback(StreamingChunk(content=content, meta=metadata))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001

tool = Tool(function_declarations=[get_current_weather_func])
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool])
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
Expand All @@ -227,7 +227,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**json.loads(chat_message.text))
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
Expand Down Expand Up @@ -260,7 +260,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001

tool = Tool(function_declarations=[get_current_weather_func])
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback)
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
Expand All @@ -272,8 +272,8 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
assert "function_call" in chat_message.meta
assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"}

weather = get_current_weather(**json.loads(response["replies"][0].text))
messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
weather = get_current_weather(**json.loads(chat_message.text))
messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")]
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
Expand All @@ -289,10 +289,10 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
def test_past_conversation():
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro")
messages = [
ChatMessage.from_system(content="You are a knowledageable mathematician."),
ChatMessage.from_user(content="What is 2+2?"),
ChatMessage.from_assistant(content="It's an arithmetic operation."),
ChatMessage.from_user(content="Yeah, but what's the result?"),
ChatMessage.from_system("You are a knowledageable mathematician."),
ChatMessage.from_user("What is 2+2?"),
ChatMessage.from_assistant("It's an arithmetic operation."),
ChatMessage.from_user("Yeah, but what's the result?"),
]
response = gemini_chat.run(messages=messages)
assert "replies" in response
Expand Down
Loading