From 6f899422b2cf4859ed1652958f8fc95c9ef2e52a Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 25 Jan 2024 14:49:49 +0100 Subject: [PATCH 1/6] Place both generators under cohere namespace --- .../components/generators/cohere/__init__.py | 4 +++- .../components/generators/cohere/chat/__init__.py | 6 ------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py index c36f982df..43f8201fe 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py @@ -2,5 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 from .generator import CohereGenerator +from .chat.chat_generator import CohereChatGenerator + +__all__ = ["CohereGenerator", "CohereChatGenerator"] -__all__ = ["CohereGenerator"] diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py index dc14c9c1c..e69de29bb 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -from .chat_generator import CohereChatGenerator - -__all__ = ["CohereChatGenerator"] From 9b83ce69c5d4930f74c55da2e22275f398b21256 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 25 Jan 2024 15:35:11 +0100 Subject: [PATCH 2/6] Fix issues found in pre-release checks --- .../generators/cohere/chat/chat_generator.py | 7 ++----- .../components/generators/cohere/generator.py | 17 +++++++++++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 0ff29ce14..c632bed83 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -12,6 +12,7 @@ logger = logging.getLogger(__name__) +@component class CohereChatGenerator: """Enables text generation using Cohere's chat endpoint. This component is designed to inference Cohere's chat models. @@ -123,10 +124,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator": return default_from_dict(cls, data) def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: - if message.role == ChatRole.USER: - role = "User" - elif message.role == ChatRole.ASSISTANT: - role = "Chatbot" + role = "User" if message.role == ChatRole.USER else "Chatbot" chat_message = {"user_name": role, "text": message.content} return chat_message @@ -179,7 +177,6 @@ def _build_chunk(self, chunk) -> StreamingChunk: :param choice: The choice returned by the OpenAI API. :return: The StreamingChunk. """ - # if chunk.event_type == "text-generation": chat_message = StreamingChunk(content=chunk.text, meta={"index": chunk.index, "event_type": chunk.event_type}) return chat_message diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index 7bca3ed9f..31a86a948 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -6,10 +6,10 @@ import sys from typing import Any, Callable, Dict, List, Optional, cast -from haystack import DeserializationError, component, default_from_dict, default_to_dict - from cohere import COHERE_API_URL, Client from cohere.responses import Generations +from haystack import DeserializationError, component, default_from_dict, default_to_dict +from haystack.dataclasses import StreamingChunk logger = logging.getLogger(__name__) @@ -148,8 +148,8 @@ def run(self, prompt: str): if self.streaming_callback: metadata_dict: Dict[str, Any] = {} for chunk in response: - self.streaming_callback(chunk) - metadata_dict["index"] = chunk.index + stream_chunk = self._build_chunk(chunk) + self.streaming_callback(stream_chunk) replies = response.texts metadata_dict["finish_reason"] = response.finish_reason metadata = [metadata_dict] @@ -161,6 +161,15 @@ def run(self, prompt: str): self._check_truncated_answers(metadata) return {"replies": replies, "meta": metadata} + def _build_chunk(self, chunk) -> StreamingChunk: + """ + Converts the response from the Cohere API to a StreamingChunk. + :param chunk: The chunk returned by the OpenAI API. + :return: The StreamingChunk. + """ + streaming_chunk = StreamingChunk(content=chunk.text, meta={"index": chunk.index}) + return streaming_chunk + def _check_truncated_answers(self, metadata: List[Dict[str, Any]]): """ Check the `finish_reason` returned with the Cohere response. From 3ed8fc3788a678939b74ce161dd813d00d2037cd Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 25 Jan 2024 15:46:30 +0100 Subject: [PATCH 3/6] Pylint fix --- .../components/generators/cohere/__init__.py | 3 +-- .../components/generators/cohere/generator.py | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py index 43f8201fe..93c0947e4 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py @@ -1,8 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from .generator import CohereGenerator from .chat.chat_generator import CohereChatGenerator +from .generator import CohereGenerator __all__ = ["CohereGenerator", "CohereChatGenerator"] - diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index 31a86a948..fee410eab 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -6,11 +6,12 @@ import sys from typing import Any, Callable, Dict, List, Optional, cast -from cohere import COHERE_API_URL, Client -from cohere.responses import Generations from haystack import DeserializationError, component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk +from cohere import COHERE_API_URL, Client +from cohere.responses import Generations + logger = logging.getLogger(__name__) From c4e3daaf16e29d714347962c4ab48457f5213c70 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 25 Jan 2024 15:59:19 +0100 Subject: [PATCH 4/6] Update test path --- integrations/cohere/tests/test_cohere_chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index cc360f5c9..c91ada419 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -5,7 +5,7 @@ import pytest from haystack.components.generators.utils import default_streaming_callback from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk -from haystack_integrations.components.generators.cohere.chat import CohereChatGenerator +from haystack_integrations.components.generators.cohere import CohereChatGenerator pytestmark = pytest.mark.chat_generators From a2cb24441f1a2bbae67bd2ef7daa0623c9fccdfd Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 25 Jan 2024 17:46:57 +0100 Subject: [PATCH 5/6] Keep licence in __init__.py --- .../components/generators/cohere/chat/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py index e69de29bb..49fd5f144 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file From e6b07f311e41f7d43da7b686f3eb421b5a8737b4 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 25 Jan 2024 17:48:14 +0100 Subject: [PATCH 6/6] Pylint newline --- .../components/generators/cohere/chat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py index 49fd5f144..e873bc332 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py @@ -1,3 +1,3 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # -# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file +# SPDX-License-Identifier: Apache-2.0