Skip to content

Commit

Permalink
1. Adds ChatRole and converts the default role to a Cohere-compliant role
Browse files Browse the repository at this point in the history
2. Adds a unit test for 1

Signed-off-by: sunilkumardash9 <[email protected]>
  • Loading branch information
sunilkumardash9 committed Dec 23, 2023
1 parent 33f218e commit 518a6b9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
24 changes: 17 additions & 7 deletions integrations/cohere/src/cohere_haystack/chat/chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
from haystack.lazy_imports import LazyImport

with LazyImport(message="Run 'pip install cohere'") as cohere_import:
Expand Down Expand Up @@ -119,6 +119,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator":
data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
return default_from_dict(cls, data)

def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]:
    """
    Convert a ChatMessage into the dict shape Cohere expects for `chat_history`.

    :param message: The ChatMessage to convert; its role must be USER or ASSISTANT.
    :return: A dict with Cohere's `user_name` ("User" or "Chatbot") and `text` keys.
    :raises ValueError: If the message role is neither USER nor ASSISTANT.
    """
    if message.role == ChatRole.USER:
        role = "User"
    elif message.role == ChatRole.ASSISTANT:
        role = "Chatbot"
    else:
        # Without this branch, any other role (e.g. SYSTEM) left `role` unbound
        # and crashed below with UnboundLocalError. Cohere's chat_history only
        # distinguishes "User" and "Chatbot", so reject anything else explicitly.
        msg = f"Cohere chat_history does not support role '{message.role}'."
        raise ValueError(msg)
    return {"user_name": role, "text": message.content}

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Expand All @@ -133,16 +141,20 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
"""
# update generation kwargs by merging with the generation kwargs passed to the run method
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
message = [message.content for message in messages]
chat_history = [self._message_to_dict(m) for m in messages[:-1]]
response = self.client.chat(
message=message[0], model=self.model_name, stream=self.streaming_callback is not None, **generation_kwargs
message=messages[-1].content,
model=self.model_name,
stream=self.streaming_callback is not None,
chat_history=chat_history,
**generation_kwargs,
)
if self.streaming_callback:
for chunk in response:
if chunk.event_type == "text-generation":
stream_chunk = self._build_chunk(chunk)
self.streaming_callback(stream_chunk)
chat_message = ChatMessage(content=response.texts, role=None, name=None)
chat_message = ChatMessage.from_assistant(content=response.texts)
chat_message.metadata.update(
{
"model": self.model_name,
Expand All @@ -151,7 +163,6 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
"finish_reason": response.finish_reason,
"documents": response.documents,
"citations": response.citations,
"chat-history": response.chat_history,
}
)
else:
Expand All @@ -178,7 +189,7 @@ def _build_message(self, cohere_response):
:return: The ChatMessage.
"""
content = cohere_response.text
message = ChatMessage(content=content, role=None, name=None)
message = ChatMessage.from_assistant(content=content)
message.metadata.update(
{
"model": self.model_name,
Expand All @@ -187,7 +198,6 @@ def _build_message(self, cohere_response):
"finish_reason": None,
"documents": cohere_response.documents,
"citations": cohere_response.citations,
"chat-history": cohere_response.chat_history,
}
)
return message
18 changes: 12 additions & 6 deletions integrations/cohere/tests/test_cohere_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import cohere
import pytest
from haystack.components.generators.utils import default_streaming_callback
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk

from cohere_haystack.chat.chat_generator import CohereChatGenerator

Expand Down Expand Up @@ -49,7 +49,7 @@ def streaming_chunk(text: str):

@pytest.fixture
def chat_messages():
    """Provide a single-message chat history for the generator tests."""
    question = ChatMessage(content="What's the capital of France", role=ChatRole.ASSISTANT, name=None)
    return [question]


class TestCohereChatGenerator:
Expand Down Expand Up @@ -182,6 +182,12 @@ def test_run(self, chat_messages, mock_chat_response): # noqa: ARG002
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]

@pytest.mark.unit
def test_message_to_dict(self, chat_messages):
    """Each ChatMessage is mapped to Cohere's chat_history dict shape."""
    generator = CohereChatGenerator(api_key="api-key")
    converted = [generator._message_to_dict(msg) for msg in chat_messages]
    assert converted == [{"user_name": "Chatbot", "text": "What's the capital of France"}]

@pytest.mark.unit
def test_run_with_params(self, chat_messages, mock_chat_response):
component = CohereChatGenerator(
Expand Down Expand Up @@ -239,7 +245,7 @@ def mock_iter(self): # noqa: ARG001
)
@pytest.mark.integration
def test_live_run(self):
chat_messages = [ChatMessage(content="What's the capital of France", role=None, name="", metadata={})]
chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", metadata={})]
component = CohereChatGenerator(
api_key=os.environ.get("COHERE_API_KEY"), generation_kwargs={"temperature": 0.8}
)
Expand All @@ -257,7 +263,7 @@ def test_live_run_wrong_model(self, chat_messages):
component = CohereChatGenerator(
model_name="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")
)
with pytest.raises(cohere.CohereAPIError, match=r"^Model not found (.+)$"):
with pytest.raises(cohere.CohereAPIError, match="finetuned model something-obviously-wrong is not valid"):
component.run(chat_messages)

@pytest.mark.skipif(
Expand All @@ -278,7 +284,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
callback = Callback()
component = CohereChatGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback)
results = component.run(
[ChatMessage(content="What's the capital of France? answer in a word", role=None, name=None)]
[ChatMessage(content="What's the capital of France? answer in a word", role=ChatRole.USER, name=None)]
)

assert len(results["replies"]) == 1
Expand All @@ -296,7 +302,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
)
@pytest.mark.integration
def test_live_run_with_connector(self):
chat_messages = [ChatMessage(content="What's the capital of France", role=None, name="", metadata={})]
chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", metadata={})]
component = CohereChatGenerator(
api_key=os.environ.get("COHERE_API_KEY"), generation_kwargs={"temperature": 0.8}
)
Expand Down

0 comments on commit 518a6b9

Please sign in to comment.