Skip to content

Commit

Permalink
use class methods to create ChatMessage (deepset-ai#1222)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 authored Nov 28, 2024
1 parent 5de49be commit 94a29cb
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 68 deletions.
4 changes: 1 addition & 3 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,8 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):

generator.model_adapter.get_responses = MagicMock(
return_value=[
ChatMessage(
ChatMessage.from_assistant(
content="Some text",
role=ChatRole.ASSISTANT,
name=None,
meta={
"model": "claude-3-sonnet-20240229",
"index": 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Callable, Dict, List, Optional

from haystack import component
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret

from .chat.chat_generator import CohereChatGenerator
Expand Down Expand Up @@ -64,7 +64,7 @@ def run(self, prompt: str):
- `replies`: A list of replies generated by the model.
- `meta`: Information about the request.
"""
chat_message = ChatMessage(content=prompt, role=ChatRole.USER, name="", meta={})
chat_message = ChatMessage.from_user(prompt)
# Note we have to call super() like this because of the way components are dynamically built with the decorator
results = super(CohereGenerator, self).run([chat_message]) # noqa
return {"replies": [results["replies"][0].content], "meta": [results["replies"][0].meta]}
12 changes: 5 additions & 7 deletions integrations/cohere/tests/test_cohere_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def streaming_chunk(text: str):

@pytest.fixture
def chat_messages():
return [ChatMessage(content="What's the capital of France", role=ChatRole.ASSISTANT, name=None)]
return [ChatMessage.from_assistant(content="What's the capital of France")]


class TestCohereChatGenerator:
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_message_to_dict(self, chat_messages):
)
@pytest.mark.integration
def test_live_run(self):
chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", meta={})]
chat_messages = [ChatMessage.from_user(content="What's the capital of France")]
component = CohereChatGenerator(generation_kwargs={"temperature": 0.8})
results = component.run(chat_messages)
assert len(results["replies"]) == 1
Expand Down Expand Up @@ -201,9 +201,7 @@ def __call__(self, chunk: StreamingChunk) -> None:

callback = Callback()
component = CohereChatGenerator(streaming_callback=callback)
results = component.run(
[ChatMessage(content="What's the capital of France? answer in a word", role=ChatRole.USER, name=None)]
)
results = component.run([ChatMessage.from_user(content="What's the capital of France? answer in a word")])

assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
Expand All @@ -224,7 +222,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
)
@pytest.mark.integration
def test_live_run_with_connector(self):
chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", meta={})]
chat_messages = [ChatMessage.from_user(content="What's the capital of France")]
component = CohereChatGenerator(generation_kwargs={"temperature": 0.8})
results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]})
assert len(results["replies"]) == 1
Expand All @@ -249,7 +247,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
self.responses += chunk.content if chunk.content else ""

callback = Callback()
chat_messages = [ChatMessage(content="What's the capital of France? answer in a word", role=None, name=None)]
chat_messages = [ChatMessage.from_user(content="What's the capital of France? answer in a word")]
component = CohereChatGenerator(streaming_callback=callback)
results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,19 +334,14 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess

for part in candidate.content.parts:
if part.text != "":
replies.append(
ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=candidate_metadata)
)
replies.append(ChatMessage.from_assistant(content=part.text, meta=candidate_metadata))
elif part.function_call:
candidate_metadata["function_call"] = part.function_call
replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=candidate_metadata,
)
new_message = ChatMessage.from_assistant(
content=dict(part.function_call.args.items()), meta=candidate_metadata
)
new_message.name = part.function_call.name
replies.append(new_message)
return replies

def _get_stream_response(
Expand All @@ -368,18 +363,13 @@ def _get_stream_response(
for part in candidate["content"]["parts"]:
if "text" in part and part["text"] != "":
content = part["text"]
replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None))
replies.append(ChatMessage.from_assistant(content=content, meta=metadata))
elif "function_call" in part and len(part["function_call"]) > 0:
metadata["function_call"] = part["function_call"]
content = part["function_call"]["args"]
replies.append(
ChatMessage(
content=content,
role=ChatRole.ASSISTANT,
name=part["function_call"]["name"],
meta=metadata,
)
)
new_message = ChatMessage.from_assistant(content=content, meta=metadata)
new_message.name = part["function_call"]["name"]
replies.append(new_message)

streaming_callback(StreamingChunk(content=content, meta=metadata))
return replies
Original file line number Diff line number Diff line change
Expand Up @@ -279,19 +279,14 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
# Remove content from metadata
metadata.pop("content", None)
if part._raw_part.text != "":
replies.append(
ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)
)
replies.append(ChatMessage.from_assistant(content=part._raw_part.text, meta=metadata))
elif part.function_call:
metadata["function_call"] = part.function_call
replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata,
)
new_message = ChatMessage.from_assistant(
content=dict(part.function_call.args.items()), meta=metadata
)
new_message.name = part.function_call.name
replies.append(new_message)
return replies

def _get_stream_response(
Expand All @@ -313,18 +308,13 @@ def _get_stream_response(
for part in candidate.content.parts:
if part._raw_part.text:
content = chunk.text
replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata))
replies.append(ChatMessage.from_assistant(content, meta=metadata))
elif part.function_call:
metadata["function_call"] = part.function_call
content = dict(part.function_call.args.items())
replies.append(
ChatMessage(
content=content,
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata,
)
)
new_message = ChatMessage.from_assistant(content, meta=metadata)
new_message.name = part.function_call.name
replies.append(new_message)
streaming_callback(StreamingChunk(content=content, meta=metadata))

return replies
27 changes: 9 additions & 18 deletions integrations/ollama/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.dataclasses import ChatMessage
from ollama._types import ChatResponse, ResponseError

from haystack_integrations.components.generators.ollama import OllamaChatGenerator
Expand Down Expand Up @@ -128,16 +128,12 @@ def test_run_with_chat_history(self):
chat_generator = OllamaChatGenerator()

chat_history = [
{"role": "user", "content": "What is the largest city in the United Kingdom by population?"},
{"role": "assistant", "content": "London is the largest city in the United Kingdom by population"},
{"role": "user", "content": "And what is the second largest?"},
ChatMessage.from_user("What is the largest city in the United Kingdom by population?"),
ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"),
ChatMessage.from_user("And what is the second largest?"),
]

chat_messages = [
ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None)
for message in chat_history
]
response = chat_generator.run(chat_messages)
response = chat_generator.run(chat_history)

assert isinstance(response, dict)
assert isinstance(response["replies"], list)
Expand All @@ -159,17 +155,12 @@ def test_run_with_streaming(self):
chat_generator = OllamaChatGenerator(streaming_callback=streaming_callback)

chat_history = [
{"role": "user", "content": "What is the largest city in the United Kingdom by population?"},
{"role": "assistant", "content": "London is the largest city in the United Kingdom by population"},
{"role": "user", "content": "And what is the second largest?"},
]

chat_messages = [
ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None)
for message in chat_history
ChatMessage.from_user("What is the largest city in the United Kingdom by population?"),
ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"),
ChatMessage.from_user("And what is the second largest?"),
]

response = chat_generator.run(chat_messages)
response = chat_generator.run(chat_history)

streaming_callback.assert_called()

Expand Down

0 comments on commit 94a29cb

Please sign in to comment.