Skip to content

Commit

Permalink
refactor tests (#487)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 authored Feb 26, 2024
1 parent b1f8b07 commit 8e4db9e
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 29 deletions.
27 changes: 4 additions & 23 deletions integrations/mistral/tests/test_mistral_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
MistralChatGenerator.from_dict(data)

def test_run(self, chat_messages):
def test_run(self, chat_messages, mock_chat_completion, monkeypatch): # noqa: ARG002
monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key")
component = MistralChatGenerator()
response = component.run(chat_messages)

Expand All @@ -158,7 +159,8 @@ def test_run(self, chat_messages):
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]

def test_run_with_params(self, chat_messages, mock_chat_completion):
def test_run_with_params(self, chat_messages, mock_chat_completion, monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key")
component = MistralChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5})
response = component.run(chat_messages)

Expand All @@ -174,27 +176,6 @@ def test_run_with_params(self, chat_messages, mock_chat_completion):
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]

def test_run_with_params_streaming(self, chat_messages):
streaming_callback_called = False

def streaming_callback(chunk: StreamingChunk) -> None: # noqa: ARG001
nonlocal streaming_callback_called
streaming_callback_called = True

component = MistralChatGenerator(streaming_callback=streaming_callback)
response = component.run(chat_messages)

# check we called the streaming callback
assert streaming_callback_called

# check that the component still returns the correct response
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "Paris" in response["replies"][0].content # see mock_chat_completion_chunk

def test_check_abnormal_completions(self, caplog):
component = MistralChatGenerator(api_key=Secret.from_token("test-api-key"))
messages = [
Expand Down
10 changes: 7 additions & 3 deletions integrations/mistral/tests/test_mistral_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@


class TestMistralDocumentEmbedder:
def test_init_default(self):
def test_init_default(self, monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key")

embedder = MistralDocumentEmbedder()
assert embedder.api_key == Secret.from_env_var(["MISTRAL_API_KEY"])
assert embedder.model == "mistral-embed"
Expand Down Expand Up @@ -46,7 +48,9 @@ def test_init_with_parameters(self):
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == "-"

def test_to_dict(self):
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key")

embedder_component = MistralDocumentEmbedder()
component_dict = embedder_component.to_dict()
assert component_dict == {
Expand Down Expand Up @@ -99,7 +103,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):

@pytest.mark.skipif(
not os.environ.get("MISTRAL_API_KEY", None),
reason="Export an env var called MISTRAL_API_KEY containing the Cohere API key to run this test.",
reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.",
)
@pytest.mark.integration
def test_run(self):
Expand Down
14 changes: 11 additions & 3 deletions integrations/mistral/tests/test_mistral_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@


class TestMistralTextEmbedder:
def test_init_default(self):
def test_init_default(self, monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key")

embedder = MistralTextEmbedder()
assert embedder.api_key == Secret.from_env_var(["MISTRAL_API_KEY"])
assert embedder.api_base_url == "https://api.mistral.ai/v1"
assert embedder.model == "mistral-embed"
assert embedder.prefix == ""
assert embedder.suffix == ""
Expand All @@ -26,18 +29,22 @@ def test_init_with_parameters(self):
suffix="END",
)
assert embedder.api_key == Secret.from_token("test-api-key")
assert embedder.api_base_url == "https://api.mistral.ai/v1"
assert embedder.model == "mistral-embed-v2"
assert embedder.prefix == "START"
assert embedder.suffix == "END"

def test_to_dict(self):
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key")

embedder_component = MistralTextEmbedder()
component_dict = embedder_component.to_dict()
assert component_dict == {
"type": "haystack_integrations.components.embedders.mistral.text_embedder.MistralTextEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"},
"model": "mistral-embed",
"api_base_url": "https://api.mistral.ai/v1",
"dimensions": None,
"organization": None,
"prefix": "",
Expand All @@ -60,6 +67,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
"init_parameters": {
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"model": "mistral-embed-v2",
"api_base_url": "https://custom-api-base-url.com",
"dimensions": None,
"organization": None,
"prefix": "START",
Expand All @@ -69,7 +77,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):

@pytest.mark.skipif(
not os.environ.get("MISTRAL_API_KEY", None),
reason="Export an env var called MISTRAL_API_KEY containing the Cohere API key to run this test.",
reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.",
)
@pytest.mark.integration
def test_run(self):
Expand Down

0 comments on commit 8e4db9e

Please sign in to comment.