Skip to content

Commit

Permalink
fixing minor linting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
TuanaCelik committed Feb 15, 2024
1 parent e0dee40 commit 08845fa
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .document_embedder import MistralDocumentEmbedder
from .text_embedder import MistralTextEmbedder

__all__ = ["MistralDocumentEmbedder", "MistralTextEmbedder"]
__all__ = ["MistralDocumentEmbedder", "MistralTextEmbedder"]
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,16 @@ def __init__(
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""
super(MistralDocumentEmbedder, self).__init__(api_key,
model,
dimensions,
api_base_url,
organization,
prefix,
suffix,
batch_size,
progress_bar,
meta_fields_to_embed,
embedding_separator,
)
super(MistralDocumentEmbedder, self).__init__(
api_key,
model,
dimensions,
api_base_url,
organization,
prefix,
suffix,
batch_size,
progress_bar,
meta_fields_to_embed,
embedding_separator,
)
23 changes: 12 additions & 11 deletions integrations/mistral/tests/test_mistral_chat_generator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import os

import pytest
from datetime import datetime
from typing import Iterator
from unittest.mock import patch
from openai import OpenAIError
from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import ChoiceDelta, Choice

from haystack.utils.auth import Secret
from haystack_integrations.components.generators.mistral.chat.chat_generator import MistralChatGenerator
import pytest
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.auth import Secret
from haystack_integrations.components.generators.mistral.chat.chat_generator import MistralChatGenerator
from openai import OpenAIError, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta


@pytest.fixture
Expand All @@ -22,6 +20,7 @@ def chat_messages():
ChatMessage.from_user("What's the capital of France"),
]


@pytest.fixture
def mock_chat_completion():
"""
Expand All @@ -47,6 +46,7 @@ def mock_chat_completion():
mock_chat_completion_create.return_value = completion
yield mock_chat_completion_create


@pytest.fixture
def mock_chat_completion_chunk():
"""
Expand Down Expand Up @@ -78,6 +78,7 @@ def __stream__(self) -> Iterator[ChatCompletionChunk]:
mock_chat_completion_create.return_value = MockStream(completion, cast_to=None, response=None, client=None)
yield mock_chat_completion_create


class TestMistralChatGenerator:
def test_init_default(self, monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key")
Expand Down Expand Up @@ -142,7 +143,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
}

def test_from_dict(self, monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key")
Expand Down Expand Up @@ -179,7 +180,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
MistralChatGenerator.from_dict(data)

def test_run(self, chat_messages, mock_chat_completion):
def test_run(self, chat_messages):
component = MistralChatGenerator(api_key=Secret.from_token("test-api-key"))
response = component.run(chat_messages)

Expand Down Expand Up @@ -208,7 +209,7 @@ def test_run_with_params(self, chat_messages, mock_chat_completion):
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]

def test_run_with_params_streaming(self, chat_messages, mock_chat_completion_chunk):
def test_run_with_params_streaming(self, chat_messages):
streaming_callback_called = False

def streaming_callback(chunk: StreamingChunk) -> None:
Expand Down
5 changes: 4 additions & 1 deletion integrations/mistral/tests/test_mistral_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ def test_run(self):
def test_run_wrong_input_format(self):
embedder = MistralDocumentEmbedder(api_key=Secret.from_token("test-api-key"))

match_error_msg = "OpenAIDocumentEmbedder expects a list of Documents as input.In case you want to embed a string, please use the OpenAITextEmbedder."
match_error_msg = (
"OpenAIDocumentEmbedder expects a list of Documents as input.In case you want to embed a string, "
"please use the OpenAITextEmbedder."
)

with pytest.raises(TypeError, match=match_error_msg):
embedder.run(documents="text")
Expand Down
10 changes: 6 additions & 4 deletions integrations/mistral/tests/test_mistral_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os

import pytest
from haystack import Document
from haystack.utils import Secret
from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder

Expand Down Expand Up @@ -79,10 +78,13 @@ def test_run(self):
result = embedder.run(text)
assert all(isinstance(x, float) for x in result["embedding"])


def test_run_wrong_input_format(self):
embedder = MistralTextEmbedder(api_key=Secret.from_token("test-api-key"))
list_integers_input = ["text_snippet_1", "text_snippet_2"]
match_error_msg = "OpenAITextEmbedder expects a string as an input.In case you want to embed a list of Documents, please use the OpenAIDocumentEmbedder."
match_error_msg = (
"OpenAITextEmbedder expects a string as an input.In case you want to embed a list of Documents,"
" please use the OpenAIDocumentEmbedder."
)

with pytest.raises(TypeError, match=match_error_msg):
embedder.run(text=list_integers_input)
embedder.run(text=list_integers_input)

0 comments on commit 08845fa

Please sign in to comment.