Commit
fix: Cohere namespace reorg (#271)
* Place both generators under cohere namespace

* Fix issues found in pre-release checks

* Pylint fix

* Update test path

* Keep licence in __init__.py

* Pylint newline
vblagoje authored Jan 25, 2024
1 parent 8db28ee commit 1bbce8b
Showing 5 changed files with 17 additions and 12 deletions.
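The user-visible effect of the reorg is the import path: CohereChatGenerator now sits next to CohereGenerator in the cohere package root instead of only in the chat subpackage. A minimal before/after sketch, with the old path taken from the line this commit removes from the test file below:

    # before this commit (path removed from the test below)
    # from haystack_integrations.components.generators.cohere.chat import CohereChatGenerator

    # after this commit
    from haystack_integrations.components.generators.cohere import CohereChatGenerator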
Package __init__.py (haystack_integrations.components.generators.cohere):

@@ -1,6 +1,7 @@
 # SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
 #
 # SPDX-License-Identifier: Apache-2.0
+from .chat.chat_generator import CohereChatGenerator
 from .generator import CohereGenerator
 
-__all__ = ["CohereGenerator"]
+__all__ = ["CohereGenerator", "CohereChatGenerator"]
Chat subpackage __init__.py (haystack_integrations.components.generators.cohere.chat):

@@ -1,6 +1,3 @@
 # SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
 #
 # SPDX-License-Identifier: Apache-2.0
-from .chat_generator import CohereChatGenerator
-
-__all__ = ["CohereChatGenerator"]
chat_generator.py (haystack_integrations.components.generators.cohere.chat.chat_generator):

@@ -12,6 +12,7 @@
 logger = logging.getLogger(__name__)
 
+
 @component
 class CohereChatGenerator:
     """Enables text generation using Cohere's chat endpoint. This component is designed to inference
     Cohere's chat models.
@@ -123,10 +124,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator":
         return default_from_dict(cls, data)
 
     def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]:
-        if message.role == ChatRole.USER:
-            role = "User"
-        elif message.role == ChatRole.ASSISTANT:
-            role = "Chatbot"
+        role = "User" if message.role == ChatRole.USER else "Chatbot"
         chat_message = {"user_name": role, "text": message.content}
         return chat_message
 
@@ -179,7 +177,6 @@ def _build_chunk(self, chunk) -> StreamingChunk:
         :param choice: The choice returned by the OpenAI API.
         :return: The StreamingChunk.
         """
-        # if chunk.event_type == "text-generation":
         chat_message = StreamingChunk(content=chunk.text, meta={"index": chunk.index, "event_type": chunk.event_type})
         return chat_message
 
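The _message_to_dict simplification collapses the role branch into a conditional expression, so any non-user role is labelled "Chatbot" for Cohere's chat endpoint. A standalone sketch of the same mapping, assuming the usual ChatMessage.from_user / from_assistant constructors (illustrative, not the component method itself):

    from haystack.dataclasses import ChatMessage, ChatRole

    def message_to_dict(message: ChatMessage) -> dict:
        # Cohere's chat endpoint expects "User" / "Chatbot" speaker labels
        role = "User" if message.role == ChatRole.USER else "Chatbot"
        return {"user_name": role, "text": message.content}

    # message_to_dict(ChatMessage.from_user("Hi there"))
    #   -> {"user_name": "User", "text": "Hi there"}
    # message_to_dict(ChatMessage.from_assistant("Hello back"))
    #   -> {"user_name": "Chatbot", "text": "Hello back"}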
generator.py (haystack_integrations.components.generators.cohere.generator):

@@ -7,6 +7,7 @@
 from typing import Any, Callable, Dict, List, Optional, cast
 
 from haystack import DeserializationError, component, default_from_dict, default_to_dict
+from haystack.dataclasses import StreamingChunk
 
 from cohere import COHERE_API_URL, Client
 from cohere.responses import Generations
@@ -148,8 +149,8 @@ def run(self, prompt: str):
         if self.streaming_callback:
             metadata_dict: Dict[str, Any] = {}
             for chunk in response:
-                self.streaming_callback(chunk)
-                metadata_dict["index"] = chunk.index
+                stream_chunk = self._build_chunk(chunk)
+                self.streaming_callback(stream_chunk)
             replies = response.texts
             metadata_dict["finish_reason"] = response.finish_reason
             metadata = [metadata_dict]
@@ -161,6 +162,15 @@ def run(self, prompt: str):
         self._check_truncated_answers(metadata)
         return {"replies": replies, "meta": metadata}
 
+    def _build_chunk(self, chunk) -> StreamingChunk:
+        """
+        Converts the response from the Cohere API to a StreamingChunk.
+        :param chunk: The chunk returned by the OpenAI API.
+        :return: The StreamingChunk.
+        """
+        streaming_chunk = StreamingChunk(content=chunk.text, meta={"index": chunk.index})
+        return streaming_chunk
+
     def _check_truncated_answers(self, metadata: List[Dict[str, Any]]):
         """
         Check the `finish_reason` returned with the Cohere response.
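With _build_chunk in place, CohereGenerator's streaming callback now receives a Haystack StreamingChunk (text in content, chunk index in meta) rather than the raw Cohere SDK chunk. A minimal sketch of a compatible callback; the wiring in the comments assumes the usual streaming_callback constructor argument and is illustrative only:

    from haystack.dataclasses import StreamingChunk

    def print_chunk(chunk: StreamingChunk) -> None:
        # content carries the streamed text; meta["index"] is set by _build_chunk
        print(chunk.content, end="", flush=True)

    # hypothetical wiring (constructor arguments are not shown in this diff):
    # generator = CohereGenerator(streaming_callback=print_chunk)
    # generator.run(prompt="Briefly explain retrieval-augmented generation.")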
integrations/cohere/tests/test_cohere_chat_generator.py (1 addition, 1 deletion):

@@ -5,7 +5,7 @@
 import pytest
 from haystack.components.generators.utils import default_streaming_callback
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
-from haystack_integrations.components.generators.cohere.chat import CohereChatGenerator
+from haystack_integrations.components.generators.cohere import CohereChatGenerator
 
 pytestmark = pytest.mark.chat_generators

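The test now imports CohereChatGenerator from the package root, matching the new namespace. A hedged usage sketch under that path; the run(messages=...) signature follows the common Haystack chat-generator interface, and the constructor arguments are assumed rather than taken from this commit:

    from haystack.dataclasses import ChatMessage
    from haystack_integrations.components.generators.cohere import CohereChatGenerator

    # assumes the Cohere API key is available to the component, e.g. via the environment
    generator = CohereChatGenerator()
    result = generator.run(messages=[ChatMessage.from_user("What is Cohere Command?")])
    # result["replies"] is expected to hold the generated ChatMessage objects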