diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py
index c36f982df..93c0947e4 100644
--- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py
+++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py
@@ -1,6 +1,7 @@
 # SPDX-FileCopyrightText: 2023-present deepset GmbH
 #
 # SPDX-License-Identifier: Apache-2.0
+from .chat.chat_generator import CohereChatGenerator
 from .generator import CohereGenerator
 
-__all__ = ["CohereGenerator"]
+__all__ = ["CohereGenerator", "CohereChatGenerator"]
diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py
index dc14c9c1c..e873bc332 100644
--- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py
+++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py
@@ -1,6 +1,3 @@
 # SPDX-FileCopyrightText: 2023-present deepset GmbH
 #
 # SPDX-License-Identifier: Apache-2.0
-from .chat_generator import CohereChatGenerator
-
-__all__ = ["CohereChatGenerator"]
diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py
index 0ff29ce14..c632bed83 100644
--- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py
+++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py
@@ -12,6 +12,7 @@
 
 logger = logging.getLogger(__name__)
 
+@component
 class CohereChatGenerator:
     """Enables text generation using Cohere's chat endpoint. This component is designed to inference
     Cohere's chat models.
@@ -123,10 +124,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator":
         return default_from_dict(cls, data)
 
     def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]:
-        if message.role == ChatRole.USER:
-            role = "User"
-        elif message.role == ChatRole.ASSISTANT:
-            role = "Chatbot"
+        role = "User" if message.role == ChatRole.USER else "Chatbot"
         chat_message = {"user_name": role, "text": message.content}
         return chat_message
 
@@ -179,7 +177,6 @@ def _build_chunk(self, chunk) -> StreamingChunk:
         :param choice: The choice returned by the OpenAI API.
         :return: The StreamingChunk.
         """
-        # if chunk.event_type == "text-generation":
         chat_message = StreamingChunk(content=chunk.text, meta={"index": chunk.index, "event_type": chunk.event_type})
         return chat_message
 
diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py
index 7bca3ed9f..fee410eab 100644
--- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py
+++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py
@@ -7,6 +7,7 @@
 from typing import Any, Callable, Dict, List, Optional, cast
 
 from haystack import DeserializationError, component, default_from_dict, default_to_dict
+from haystack.dataclasses import StreamingChunk
 
 from cohere import COHERE_API_URL, Client
 from cohere.responses import Generations
@@ -148,8 +149,8 @@ def run(self, prompt: str):
         if self.streaming_callback:
             metadata_dict: Dict[str, Any] = {}
             for chunk in response:
-                self.streaming_callback(chunk)
-                metadata_dict["index"] = chunk.index
+                stream_chunk = self._build_chunk(chunk)
+                self.streaming_callback(stream_chunk)
             replies = response.texts
             metadata_dict["finish_reason"] = response.finish_reason
             metadata = [metadata_dict]
@@ -161,6 +162,15 @@
         self._check_truncated_answers(metadata)
         return {"replies": replies, "meta": metadata}
 
+    def _build_chunk(self, chunk) -> StreamingChunk:
+        """
+        Converts the response from the Cohere API to a StreamingChunk.
+        :param chunk: The chunk returned by the Cohere API.
+        :return: The StreamingChunk.
+        """
+        streaming_chunk = StreamingChunk(content=chunk.text, meta={"index": chunk.index})
+        return streaming_chunk
+
     def _check_truncated_answers(self, metadata: List[Dict[str, Any]]):
         """
         Check the `finish_reason` returned with the Cohere response.
diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py
index cc360f5c9..c91ada419 100644
--- a/integrations/cohere/tests/test_cohere_chat_generator.py
+++ b/integrations/cohere/tests/test_cohere_chat_generator.py
@@ -5,7 +5,7 @@
 import pytest
 from haystack.components.generators.utils import default_streaming_callback
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
-from haystack_integrations.components.generators.cohere.chat import CohereChatGenerator
+from haystack_integrations.components.generators.cohere import CohereChatGenerator
 
 pytestmark = pytest.mark.chat_generators