Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: CohereGenerator #6034

Closed
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
0cbf196
added CohereGenerator with unit tests
sunilkumardash9 Oct 12, 2023
3576a4a
Merge branch 'main' into add-cohere-generator
sunilkumardash9 Oct 12, 2023
2f48f60
Merge branch 'main' into add-cohere-generator
masci Oct 13, 2023
bdaf1c2
1. added releasenote
sunilkumardash9 Oct 13, 2023
b518339
Merge branch 'main' into add-cohere-generator
sunilkumardash9 Oct 13, 2023
e2547aa
Merge remote-tracking branch 'origin/add-cohere-generator' into add-c…
sunilkumardash9 Oct 13, 2023
127e2c7
1. move client creation to __init__
sunilkumardash9 Oct 13, 2023
bcd855e
few fixes
sunilkumardash9 Oct 13, 2023
3fa2f0d
add cohere to git workflows
sunilkumardash9 Oct 13, 2023
f58e602
1. CohereGenerator as top level import in generators
sunilkumardash9 Oct 13, 2023
e074a5e
1. corrected git workflow files for cohere import
sunilkumardash9 Oct 16, 2023
44a3f14
Merge branch 'main' into add-cohere-generator
ZanSara Oct 16, 2023
106e7fc
added cohere in missed out workflow installs
sunilkumardash9 Oct 16, 2023
3b4574d
Merge branch 'main' into add-cohere-generator
ZanSara Oct 16, 2023
4ece817
1. Removed default_streaming_callback from cohere.py and added in test.
sunilkumardash9 Oct 19, 2023
bb63179
Update haystack/preview/components/generators/cohere/cohere.py
sunilkumardash9 Oct 20, 2023
35a1304
Update haystack/preview/components/generators/cohere/cohere.py
sunilkumardash9 Oct 20, 2023
24359a0
Update haystack/preview/components/generators/cohere/cohere.py
sunilkumardash9 Oct 20, 2023
6449bd2
Update haystack/preview/components/generators/cohere/cohere.py
sunilkumardash9 Oct 20, 2023
b43c698
Update haystack/preview/components/generators/cohere/cohere.py
sunilkumardash9 Oct 20, 2023
282ebab
Merge branch 'main' into add-cohere-generator
ZanSara Nov 23, 2023
e51d787
move out of folder
ZanSara Nov 23, 2023
2378802
Merge branch 'main' into add-cohere-generator
ZanSara Nov 23, 2023
1d8c849
black
ZanSara Nov 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
145 changes: 145 additions & 0 deletions haystack/preview/components/generators/cohere/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import logging
import sys
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, List, Optional

from haystack.lazy_imports import LazyImport

with LazyImport(message="Run 'pip install cohere'") as cohere_import:
import cohere
from haystack.preview import (DeserializationError, component,
default_from_dict, default_to_dict)

# Module-level logger, named after this module per haystack convention.
logger = logging.getLogger(__name__)


# Default base URL for the Cohere API.
# NOTE(review): this duplicates cohere.COHERE_API_URL — consider referencing the
# SDK constant instead so upstream URL changes are picked up automatically; it
# cannot be read eagerly here because `cohere` is imported lazily above.
API_BASE_URL = 'https://api.cohere.ai'


def default_streaming_callback(chunk):
    """
    Default callback function for streaming responses from the Cohere API.

    Prints the chunk's text to stdout as soon as it is received and returns the
    chunk unchanged so callers can keep processing it.

    :param chunk: A streaming chunk object exposing a `text` attribute.
    :return: The same chunk, unmodified.
    """
    print(chunk.text, flush=True, end="")
    # Bug fix: the docstring promised the chunk back, but the original body
    # implicitly returned None.
    return chunk

@component
class CohereGenerator:
    """LLM Generator compatible with Cohere's `generate` endpoint."""

    def __init__(
        self,
        api_key: str,
        model: str = "command",
        streaming_callback: Optional[Callable] = None,
        api_base_url: str = API_BASE_URL,
        **kwargs,
    ):
        """
        :param api_key: The API key for the Cohere API.
        :param model: The name of the Cohere model to use. Defaults to "command".
        :param streaming_callback: A callback invoked with each streamed chunk.
            If None (the default), streaming is disabled and a single response
            is returned.
        :param api_base_url: The base URL of the Cohere API. Defaults to
            `API_BASE_URL`.
        :param kwargs: Additional model parameters (e.g. `max_tokens`,
            `temperature`) forwarded verbatim to `cohere.Client.generate`.
        """
        self.api_key = api_key
        self.model = model
        self.streaming_callback = streaming_callback
        self.api_base_url = api_base_url
        self.model_parameters = kwargs

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        The streaming callback is stored as its fully qualified import path
        (just the bare name for builtins) so it can be re-imported by
        `from_dict`; lambdas serialize to a non-importable name on purpose.
        """
        if self.streaming_callback:
            module = self.streaming_callback.__module__
            if module == "builtins":
                callback_name = self.streaming_callback.__name__
            else:
                callback_name = f"{module}.{self.streaming_callback.__name__}"
        else:
            callback_name = None

        return default_to_dict(
            self,
            api_key=self.api_key,
            model=self.model,
            streaming_callback=callback_name,
            api_base_url=self.api_base_url,
            **self.model_parameters,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator":
        """
        Deserialize this component from a dictionary.

        :param data: The serialized component, as produced by `to_dict`.
        :raises DeserializationError: If the streaming callback's module is not
            already imported or the callback name cannot be resolved in it.
        """
        init_params = data.get("init_parameters", {})
        streaming_callback = None
        if "streaming_callback" in init_params and init_params["streaming_callback"]:
            parts = init_params["streaming_callback"].split(".")
            module_name = ".".join(parts[:-1])
            function_name = parts[-1]
            # The callback's module must already be imported; we never import on
            # behalf of the caller during deserialization.
            module = sys.modules.get(module_name, None)
            if not module:
                raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
            streaming_callback = getattr(module, function_name, None)
            if not streaming_callback:
                raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
            data["init_parameters"]["streaming_callback"] = streaming_callback
        return default_from_dict(cls, data)

    @component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
    def run(self, prompt: str):
        """
        Queries the LLM with the prompt to produce replies.

        :param prompt: The prompt to be sent to the generative model.
        :return: A dict with "replies" (a list of generated strings) and
            "metadata" (one dict per reply, containing at least "finish_reason").
        """
        # NOTE(review): a new client is built on every call; consider moving this
        # to __init__ once the lazy `cohere` import allows it.
        co = cohere.Client(api_key=self.api_key, api_url=self.api_base_url)
        response = co.generate(
            model=self.model,
            prompt=prompt,
            # Streaming is enabled iff a callback was provided.
            stream=self.streaming_callback is not None,
            **self.model_parameters,
        )
        if self.streaming_callback:
            # Streaming: forward each chunk to the callback as it arrives and
            # collect the full texts/metadata from the exhausted response.
            metadata_dict: Dict[str, Any] = {}
            for chunk in response:
                self.streaming_callback(chunk)
                metadata_dict["index"] = chunk.index
            replies = response.texts
            metadata_dict["finish_reason"] = response.finish_reason
            metadata = [dict(metadata_dict)]
            self._check_truncated_answers(metadata)
            return {"replies": replies, "metadata": metadata}

        # Non-streaming: a single completed generation is returned.
        metadata = [
            {
                "finish_reason": response[0].finish_reason,
            }
        ]
        replies = [response[0].text]
        self._check_truncated_answers(metadata)
        return {"replies": replies, "metadata": metadata}

    def _check_truncated_answers(self, metadata: List[Dict[str, Any]]):
        """
        Check the `finish_reason` returned with the Cohere response.
        If the `finish_reason` is `MAX_TOKENS`, log a warning to the user.

        :param metadata: The metadata list produced by `run`; only the first
            entry's "finish_reason" is inspected.
        """
        if metadata[0]["finish_reason"] == "MAX_TOKENS":
            logger.warning(
                "Responses have been truncated before reaching a natural stopping point. "
                "Increase the max_tokens parameter to allow for longer completions.",
            )




189 changes: 189 additions & 0 deletions test/preview/components/generators/cohere/test_cohere_generators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import os
from unittest.mock import patch, Mock
from copy import deepcopy

import pytest
import cohere
from cohere.responses.generation import StreamingText

from haystack.preview.components.generators.cohere.cohere import CohereGenerator
from haystack.preview.components.generators.cohere.cohere import default_streaming_callback


class TestCohereGenerator:
    """Unit and integration tests for CohereGenerator (class was misnamed TestGPTGenerator)."""

    @pytest.mark.unit
    def test_init_default(self):
        component = CohereGenerator(api_key="test-api-key")
        assert component.api_key == "test-api-key"
        assert component.model == "command"
        assert component.streaming_callback is None
        assert component.api_base_url == cohere.COHERE_API_URL
        assert component.model_parameters == {}

    @pytest.mark.unit
    def test_init_with_parameters(self):
        callback = lambda x: x
        component = CohereGenerator(
            api_key="test-api-key",
            model="command-light",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=callback,
            api_base_url="test-base-url",
        )
        assert component.api_key == "test-api-key"
        assert component.model == "command-light"
        assert component.streaming_callback == callback
        assert component.api_base_url == "test-base-url"
        assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

    @pytest.mark.unit
    def test_to_dict_default(self):
        component = CohereGenerator(api_key="test-api-key")
        data = component.to_dict()
        assert data == {
            "type": "CohereGenerator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model": "command",
                "streaming_callback": None,
                "api_base_url": cohere.COHERE_API_URL,
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_parameters(self):
        component = CohereGenerator(
            api_key="test-api-key",
            model="command-light",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=default_streaming_callback,
            api_base_url="test-base-url",
        )
        data = component.to_dict()
        assert data == {
            "type": "CohereGenerator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model": "command-light",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.preview.components.generators.cohere.cohere.default_streaming_callback",
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_lambda_streaming_callback(self):
        component = CohereGenerator(
            api_key="test-api-key",
            model="command",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=lambda x: x,
            api_base_url="test-base-url",
        )
        data = component.to_dict()
        assert data == {
            "type": "CohereGenerator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model": "command",
                "streaming_callback": "test_cohere_generators.<lambda>",
                "api_base_url": "test-base-url",
                "max_tokens": 10,
                "some_test_param": "test-params",
            },
        }

    @pytest.mark.unit
    def test_from_dict(self):
        data = {
            "type": "CohereGenerator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model": "command",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.preview.components.generators.cohere.cohere.default_streaming_callback",
            },
        }
        component = CohereGenerator.from_dict(data)
        assert component.api_key == "test-api-key"
        assert component.model == "command"
        assert component.streaming_callback == default_streaming_callback
        assert component.api_base_url == "test-base-url"
        assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

    # TODO: add a unit test that mocks cohere.Client and asserts the arguments
    # passed to `generate` (a commented-out copy of the GPT generator's mocked
    # test lived here; it referenced undefined helpers and was removed).

    @pytest.mark.unit
    def test_check_truncated_answers(self, caplog):
        component = CohereGenerator(api_key="test-api-key")
        metadata = [
            {"finish_reason": "MAX_TOKENS"},
        ]
        component._check_truncated_answers(metadata)
        assert caplog.records[0].message == (
            "Responses have been truncated before reaching a natural stopping point. "
            "Increase the max_tokens parameter to allow for longer completions."
        )

    @pytest.mark.skipif(
        not os.environ.get("CO_API_KEY", None),
        reason="Export an env var called CO_API_KEY containing the Cohere API key to run this test.",
    )
    @pytest.mark.integration
    def test_cohere_generator_run(self):
        component = CohereGenerator(api_key=os.environ.get("CO_API_KEY"))
        results = component.run(prompt="What's the capital of France?")
        assert len(results["replies"]) == 1
        assert "Paris" in results["replies"][0]
        assert len(results["metadata"]) == 1
        assert results["metadata"][0]["finish_reason"] == "COMPLETE"

    @pytest.mark.skipif(
        not os.environ.get("CO_API_KEY", None),
        reason="Export an env var called CO_API_KEY containing the Cohere API key to run this test.",
    )
    @pytest.mark.integration
    def test_cohere_generator_run_wrong_model_name(self):
        component = CohereGenerator(model="something-obviously-wrong", api_key=os.environ.get("CO_API_KEY"))
        with pytest.raises(
            cohere.CohereAPIError,
            match="model not found, make sure the correct model ID was used and that you have access to the model.",
        ):
            component.run(prompt="What's the capital of France?")

    @pytest.mark.skipif(
        not os.environ.get("CO_API_KEY", None),
        reason="Export an env var called CO_API_KEY containing the Cohere API key to run this test.",
    )
    @pytest.mark.integration
    def test_cohere_generator_run_streaming(self):
        class Callback:
            def __init__(self):
                self.responses = ""

            def __call__(self, chunk):
                self.responses += chunk.text
                return chunk

        callback = Callback()
        component = CohereGenerator(os.environ.get("CO_API_KEY"), streaming_callback=callback)
        results = component.run(prompt="What's the capital of France?")

        assert len(results["replies"]) == 1
        assert "Paris" in results["replies"][0]
        assert len(results["metadata"]) == 1
        assert results["metadata"][0]["finish_reason"] == "COMPLETE"
        assert callback.responses == results["replies"][0]