From 518a6b9dd56e5f91ecd4be459a9ba72c0c9ddfb2 Mon Sep 17 00:00:00 2001 From: sunilkumardash9 Date: Sat, 23 Dec 2023 13:41:01 +0530 Subject: [PATCH] 1. Adds ChatRole and converts the default role to a Cohere-compliant role 2. Adds a unit test for 1 Signed-off-by: sunilkumardash9 --- .../cohere_haystack/chat/chat_generator.py | 24 +++++++++++++------ .../tests/test_cohere_chat_generator.py | 18 +++++++++----- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/integrations/cohere/src/cohere_haystack/chat/chat_generator.py b/integrations/cohere/src/cohere_haystack/chat/chat_generator.py index 0ac53bd5d..f3178d567 100644 --- a/integrations/cohere/src/cohere_haystack/chat/chat_generator.py +++ b/integrations/cohere/src/cohere_haystack/chat/chat_generator.py @@ -4,7 +4,7 @@ from haystack import component, default_from_dict, default_to_dict from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack.lazy_imports import LazyImport with LazyImport(message="Run 'pip install cohere'") as cohere_import: @@ -119,6 +119,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator": data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) return default_from_dict(cls, data) + def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: + if message.role == ChatRole.USER: + role = "User" + elif message.role == ChatRole.ASSISTANT: + role = "Chatbot" + chat_message = {"user_name": role, "text": message.content} + return chat_message + @component.output_types(replies=List[ChatMessage]) def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): """ @@ -133,16 +141,20 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, """ # update generation kwargs by 
merging with the generation kwargs passed to the run method generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - message = [message.content for message in messages] + chat_history = [self._message_to_dict(m) for m in messages[:-1]] response = self.client.chat( - message=message[0], model=self.model_name, stream=self.streaming_callback is not None, **generation_kwargs + message=messages[-1].content, + model=self.model_name, + stream=self.streaming_callback is not None, + chat_history=chat_history, + **generation_kwargs, ) if self.streaming_callback: for chunk in response: if chunk.event_type == "text-generation": stream_chunk = self._build_chunk(chunk) self.streaming_callback(stream_chunk) - chat_message = ChatMessage(content=response.texts, role=None, name=None) + chat_message = ChatMessage.from_assistant(content=response.texts) chat_message.metadata.update( { "model": self.model_name, @@ -151,7 +163,6 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, "finish_reason": response.finish_reason, "documents": response.documents, "citations": response.citations, - "chat-history": response.chat_history, } ) else: @@ -178,7 +189,7 @@ def _build_message(self, cohere_response): :return: The ChatMessage. 
""" content = cohere_response.text - message = ChatMessage(content=content, role=None, name=None) + message = ChatMessage.from_assistant(content=content) message.metadata.update( { "model": self.model_name, @@ -187,7 +198,6 @@ def _build_message(self, cohere_response): "finish_reason": None, "documents": cohere_response.documents, "citations": cohere_response.citations, - "chat-history": cohere_response.chat_history, } ) return message diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 6b84eb9b2..92954df8b 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -4,7 +4,7 @@ import cohere import pytest from haystack.components.generators.utils import default_streaming_callback -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from cohere_haystack.chat.chat_generator import CohereChatGenerator @@ -49,7 +49,7 @@ def streaming_chunk(text: str): @pytest.fixture def chat_messages(): - return [ChatMessage(content="What's the capital of France", role=None, name=None)] + return [ChatMessage(content="What's the capital of France", role=ChatRole.ASSISTANT, name=None)] class TestCohereChatGenerator: @@ -182,6 +182,12 @@ def test_run(self, chat_messages, mock_chat_response): # noqa: ARG002 assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + @pytest.mark.unit + def test_message_to_dict(self, chat_messages): + obj = CohereChatGenerator(api_key="api-key") + dictionary = [obj._message_to_dict(message) for message in chat_messages] + assert dictionary == [{"user_name": "Chatbot", "text": "What's the capital of France"}] + @pytest.mark.unit def test_run_with_params(self, chat_messages, mock_chat_response): component = CohereChatGenerator( @@ -239,7 +245,7 @@ def mock_iter(self): # 
noqa: ARG001 ) @pytest.mark.integration def test_live_run(self): - chat_messages = [ChatMessage(content="What's the capital of France", role=None, name="", metadata={})] + chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", metadata={})] component = CohereChatGenerator( api_key=os.environ.get("COHERE_API_KEY"), generation_kwargs={"temperature": 0.8} ) @@ -257,7 +263,7 @@ def test_live_run_wrong_model(self, chat_messages): component = CohereChatGenerator( model_name="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY") ) - with pytest.raises(cohere.CohereAPIError, match=r"^Model not found (.+)$"): + with pytest.raises(cohere.CohereAPIError, match="finetuned model something-obviously-wrong is not valid"): component.run(chat_messages) @pytest.mark.skipif( @@ -278,7 +284,7 @@ def __call__(self, chunk: StreamingChunk) -> None: callback = Callback() component = CohereChatGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback) results = component.run( - [ChatMessage(content="What's the capital of France? answer in a word", role=None, name=None)] + [ChatMessage(content="What's the capital of France? answer in a word", role=ChatRole.USER, name=None)] ) assert len(results["replies"]) == 1 @@ -296,7 +302,7 @@ def __call__(self, chunk: StreamingChunk) -> None: ) @pytest.mark.integration def test_live_run_with_connector(self): - chat_messages = [ChatMessage(content="What's the capital of France", role=None, name="", metadata={})] + chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", metadata={})] component = CohereChatGenerator( api_key=os.environ.get("COHERE_API_KEY"), generation_kwargs={"temperature": 0.8} )