Skip to content

Commit

Permalink
update to latest haystack-ai version (#348)
Browse files Browse the repository at this point in the history
  • Loading branch information
masci authored Feb 7, 2024
1 parent 30164b8 commit 6365bae
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 20 deletions.
16 changes: 8 additions & 8 deletions integrations/cohere/tests/test_cohere_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import cohere
import pytest
from haystack.components.generators.utils import default_streaming_callback
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
from haystack_integrations.components.generators.cohere import CohereChatGenerator

Expand Down Expand Up @@ -72,13 +72,13 @@ def test_init_with_parameters(self):
component = CohereChatGenerator(
api_key="test-api-key",
model="command-nightly",
streaming_callback=default_streaming_callback,
streaming_callback=print_streaming_chunk,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
assert component.api_key == "test-api-key"
assert component.model == "command-nightly"
assert component.streaming_callback is default_streaming_callback
assert component.streaming_callback is print_streaming_chunk
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

Expand All @@ -101,7 +101,7 @@ def test_to_dict_with_parameters(self):
component = CohereChatGenerator(
api_key="test-api-key",
model="command-nightly",
streaming_callback=default_streaming_callback,
streaming_callback=print_streaming_chunk,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
Expand All @@ -110,7 +110,7 @@ def test_to_dict_with_parameters(self):
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
"model": "command-nightly",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"api_base_url": "test-base-url",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
Expand Down Expand Up @@ -144,13 +144,13 @@ def test_from_dict(self, monkeypatch):
"init_parameters": {
"model": "command",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
component = CohereChatGenerator.from_dict(data)
assert component.model == "command"
assert component.streaming_callback is default_streaming_callback
assert component.streaming_callback is print_streaming_chunk
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

Expand All @@ -162,7 +162,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
"init_parameters": {
"model": "command",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
Expand Down
17 changes: 5 additions & 12 deletions integrations/cohere/tests/test_cohere_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,12 @@

import pytest
from cohere import COHERE_API_URL
from haystack.components.generators.utils import print_streaming_chunk
from haystack_integrations.components.generators.cohere import CohereGenerator

pytestmark = pytest.mark.generators


def default_streaming_callback(chunk):
"""
Default callback function for streaming responses from Cohere API.
Prints the tokens of the first completion to stdout as soon as they are received and returns the chunk unchanged.
"""
print(chunk.text, flush=True, end="") # noqa: T201


class TestCohereGenerator:
def test_init_default(self):
component = CohereGenerator(api_key="test-api-key")
Expand Down Expand Up @@ -61,7 +54,7 @@ def test_to_dict_with_parameters(self):
model="command-light",
max_tokens=10,
some_test_param="test-params",
streaming_callback=default_streaming_callback,
streaming_callback=print_streaming_chunk,
api_base_url="test-base-url",
)
data = component.to_dict()
Expand All @@ -72,7 +65,7 @@ def test_to_dict_with_parameters(self):
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "tests.test_cohere_generators.default_streaming_callback",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
},
}

Expand Down Expand Up @@ -106,13 +99,13 @@ def test_from_dict(self, monkeypatch):
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "tests.test_cohere_generators.default_streaming_callback",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
},
}
component: CohereGenerator = CohereGenerator.from_dict(data)
assert component.api_key == "test-key"
assert component.model == "command"
assert component.streaming_callback == default_streaming_callback
assert component.streaming_callback == print_streaming_chunk
assert component.api_base_url == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

Expand Down

0 comments on commit 6365bae

Please sign in to comment.