-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: CohereGenerator #6034
feat: CohereGenerator #6034
Changes from 3 commits
0cbf196
3576a4a
2f48f60
bdaf1c2
b518339
e2547aa
127e2c7
bcd855e
3fa2f0d
f58e602
e074a5e
44a3f14
106e7fc
3b4574d
4ece817
bb63179
35a1304
24359a0
6449bd2
b43c698
282ebab
e51d787
2378802
1d8c849
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import logging | ||
import sys | ||
from collections import defaultdict | ||
from dataclasses import asdict, dataclass | ||
from typing import Any, Callable, Dict, List, Optional | ||
|
||
from haystack.lazy_imports import LazyImport | ||
|
||
with LazyImport(message="Run 'pip install cohere'") as cohere_import: | ||
import cohere | ||
from haystack.preview import (DeserializationError, component, | ||
default_from_dict, default_to_dict) | ||
|
||
# Module-level logger for warnings emitted by the generator (e.g. truncation notices).
logger = logging.getLogger(__name__)


# Default base URL of the public Cohere HTTP API; overridable via the
# `api_base_url` constructor argument of CohereGenerator.
API_BASE_URL = 'https://api.cohere.ai'
|
||
|
||
def default_streaming_callback(chunk):
    """
    Default callback for streaming responses from the Cohere API.

    Writes each received token chunk to stdout as soon as it arrives,
    without appending a newline, and flushes immediately.
    """
    sys.stdout.write(chunk.text)
    sys.stdout.flush()
|
||
@component
class CohereGenerator:
    """
    LLM Generator that queries Cohere's `generate` endpoint to produce completions.

    Supports both one-shot and streamed generation: when a `streaming_callback`
    is provided, the request is made with `stream=True` and each chunk is handed
    to the callback as it arrives.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "command",
        streaming_callback: Optional[Callable] = None,
        api_base_url: str = API_BASE_URL,
        **kwargs,
    ):
        """
        :param api_key: The API key for the Cohere API.
        :param model: The name of the model to use. Defaults to "command".
        :param streaming_callback: A callback function to be called with each streamed
            response chunk. When provided, generation is streamed. Defaults to None.
        :param api_base_url: The base URL of the Cohere API. Defaults to the public endpoint.
        :param kwargs: Additional model parameters forwarded verbatim to
            `cohere.Client.generate` (e.g. `max_tokens`, `temperature`).
        """
        self.api_key = api_key
        self.model = model
        self.streaming_callback = streaming_callback
        self.api_base_url = api_base_url
        self.model_parameters = kwargs

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        The streaming callback is stored as its fully qualified import path so that
        `from_dict` can re-import it. NOTE(review): a lambda serializes to a
        `<lambda>` name that `from_dict` cannot resolve.
        """
        if self.streaming_callback:
            module = self.streaming_callback.__module__
            if module == "builtins":
                # Builtins need no module prefix to be resolvable.
                callback_name = self.streaming_callback.__name__
            else:
                callback_name = f"{module}.{self.streaming_callback.__name__}"
        else:
            callback_name = None

        return default_to_dict(
            self,
            api_key=self.api_key,
            model=self.model,
            streaming_callback=callback_name,
            api_base_url=self.api_base_url,
            **self.model_parameters,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator":
        """
        Deserialize this component from a dictionary.

        :raises DeserializationError: If the serialized streaming callback path
            cannot be resolved to an already-imported module and function.
        """
        init_params = data.get("init_parameters", {})
        streaming_callback = None
        if "streaming_callback" in init_params and init_params["streaming_callback"]:
            parts = init_params["streaming_callback"].split(".")
            module_name = ".".join(parts[:-1])
            function_name = parts[-1]
            # Only modules already present in sys.modules are considered; the
            # callback's module is not imported on the fly.
            module = sys.modules.get(module_name, None)
            if not module:
                raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
            streaming_callback = getattr(module, function_name, None)
            if not streaming_callback:
                raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
            data["init_parameters"]["streaming_callback"] = streaming_callback
        return default_from_dict(cls, data)

    @component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
    def run(self, prompt: str):
        """
        Queries the LLM with the prompt to produce replies.

        :param prompt: The prompt to be sent to the generative model.
        :return: A dictionary with the generated `replies` (list of strings) and
            per-response `metadata` (list of dicts holding the `finish_reason`).
        """
        # A fresh client per call; credentials and base URL come from __init__.
        co = cohere.Client(
            api_key=self.api_key,
            api_url=self.api_base_url,
        )

        response = co.generate(
            model=self.model,
            prompt=prompt,
            stream=self.streaming_callback is not None,
            **self.model_parameters,
        )

        if self.streaming_callback:
            metadata_dict: Dict[str, Any] = {}
            for chunk in response:
                self.streaming_callback(chunk)
                # Keeps the index of the last chunk received.
                metadata_dict["index"] = chunk.index
            replies = response.texts
            metadata_dict["finish_reason"] = response.finish_reason
            metadata = [dict(metadata_dict)]
            self._check_truncated_answers(metadata)
            return {"replies": replies, "metadata": metadata}

        # Non-streaming path: only the first generation is returned.
        metadata = [
            {
                "finish_reason": response[0].finish_reason,
            }
        ]
        replies = [response[0].text]
        self._check_truncated_answers(metadata)
        return {"replies": replies, "metadata": metadata}

    def _check_truncated_answers(self, metadata: List[Dict[str, Any]]):
        """
        Check the `finish_reason` returned with the Cohere response.
        If the `finish_reason` is `MAX_TOKENS`, log a warning to the user.
        """
        # .get avoids a KeyError if the response carried no finish_reason.
        if metadata[0].get("finish_reason") == "MAX_TOKENS":
            logger.warning(
                "Responses have been truncated before reaching a natural stopping point. "
                "Increase the max_tokens parameter to allow for longer completions.",
            )
|
||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
import os | ||
from unittest.mock import patch, Mock | ||
from copy import deepcopy | ||
|
||
import pytest | ||
import cohere | ||
from cohere.responses.generation import StreamingText | ||
|
||
from haystack.preview.components.generators.cohere.cohere import CohereGenerator | ||
from haystack.preview.components.generators.cohere.cohere import default_streaming_callback | ||
|
||
|
||
class TestCohereGenerator:
    """Unit and integration tests for the CohereGenerator component."""

    @pytest.mark.unit
    def test_init_default(self):
        component = CohereGenerator(api_key="test-api-key")
        assert component.api_key == "test-api-key"
        assert component.model == "command"
        assert component.streaming_callback is None
        assert component.api_base_url == cohere.COHERE_API_URL
        assert component.model_parameters == {}

    @pytest.mark.unit
    def test_init_with_parameters(self):
        callback = lambda x: x
        component = CohereGenerator(
            api_key="test-api-key",
            model="command-light",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=callback,
            api_base_url="test-base-url",
        )
        assert component.api_key == "test-api-key"
        assert component.model == "command-light"
        assert component.streaming_callback == callback
        assert component.api_base_url == "test-base-url"
        assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

    @pytest.mark.unit
    def test_to_dict_default(self):
        component = CohereGenerator(api_key="test-api-key")
        data = component.to_dict()
        assert data == {
            "type": "CohereGenerator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model": "command",
                "streaming_callback": None,
                "api_base_url": cohere.COHERE_API_URL,
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_parameters(self):
        component = CohereGenerator(
            api_key="test-api-key",
            model="command-light",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=default_streaming_callback,
            api_base_url="test-base-url",
        )
        data = component.to_dict()
        assert data == {
            "type": "CohereGenerator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model": "command-light",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.preview.components.generators.cohere.cohere.default_streaming_callback",
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_lambda_streaming_callback(self):
        # Lambdas serialize to a "<lambda>" qualified name that from_dict cannot resolve.
        component = CohereGenerator(
            api_key="test-api-key",
            model="command",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=lambda x: x,
            api_base_url="test-base-url",
        )
        data = component.to_dict()
        assert data == {
            "type": "CohereGenerator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model": "command",
                "streaming_callback": "test_cohere_generators.<lambda>",
                "api_base_url": "test-base-url",
                "max_tokens": 10,
                "some_test_param": "test-params",
            },
        }

    @pytest.mark.unit
    def test_from_dict(self):
        data = {
            "type": "CohereGenerator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model": "command",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.preview.components.generators.cohere.cohere.default_streaming_callback",
            },
        }
        component = CohereGenerator.from_dict(data)
        assert component.api_key == "test-api-key"
        assert component.model == "command"
        assert component.streaming_callback == default_streaming_callback
        assert component.api_base_url == "test-base-url"
        assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

    @pytest.mark.unit
    def test_check_truncated_answers(self, caplog):
        component = CohereGenerator(api_key="test-api-key")
        metadata = [
            {"finish_reason": "MAX_TOKENS"},
        ]
        component._check_truncated_answers(metadata)
        assert caplog.records[0].message == (
            "Responses have been truncated before reaching a natural stopping point. "
            "Increase the max_tokens parameter to allow for longer completions."
        )

    @pytest.mark.skipif(
        not os.environ.get("CO_API_KEY", None),
        reason="Export an env var called CO_API_KEY containing the Cohere API key to run this test.",
    )
    @pytest.mark.integration
    def test_cohere_generator_run(self):
        component = CohereGenerator(api_key=os.environ.get("CO_API_KEY"))
        results = component.run(prompt="What's the capital of France?")
        assert len(results["replies"]) == 1
        assert "Paris" in results["replies"][0]
        assert len(results["metadata"]) == 1
        assert results["metadata"][0]["finish_reason"] == "COMPLETE"

    @pytest.mark.skipif(
        not os.environ.get("CO_API_KEY", None),
        reason="Export an env var called CO_API_KEY containing the Cohere API key to run this test.",
    )
    @pytest.mark.integration
    def test_cohere_generator_run_wrong_model_name(self):
        component = CohereGenerator(model="something-obviously-wrong", api_key=os.environ.get("CO_API_KEY"))
        with pytest.raises(
            cohere.CohereAPIError,
            match="model not found, make sure the correct model ID was used and that you have access to the model.",
        ):
            component.run(prompt="What's the capital of France?")

    @pytest.mark.skipif(
        not os.environ.get("CO_API_KEY", None),
        reason="Export an env var called CO_API_KEY containing the Cohere API key to run this test.",
    )
    @pytest.mark.integration
    def test_cohere_generator_run_streaming(self):
        # Accumulating callback: collects every streamed chunk's text so the
        # final transcript can be compared against the returned reply.
        class Callback:
            def __init__(self):
                self.responses = ""

            def __call__(self, chunk):
                self.responses += chunk.text
                return chunk

        callback = Callback()
        component = CohereGenerator(os.environ.get("CO_API_KEY"), streaming_callback=callback)
        results = component.run(prompt="What's the capital of France?")

        assert len(results["replies"]) == 1
        assert "Paris" in results["replies"][0]
        assert len(results["metadata"]) == 1
        assert results["metadata"][0]["finish_reason"] == "COMPLETE"
        assert callback.responses == results["replies"][0]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can remove this function, because it's used only in the tests. To test for a streaming callback, you can define it in the tests themselves.