Skip to content

Commit

Permalink
first refactorings
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Mar 12, 2024
1 parent f95e4d0 commit 27212f5
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 60 deletions.
30 changes: 0 additions & 30 deletions integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do
assert doc_received.content == doc_expected.content
assert doc_received.meta == doc_expected.meta

@pytest.mark.unit
def test_ne_filter(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
"""
We customize this test because Chroma consider "not equal" true when
Expand All @@ -72,14 +71,12 @@ def test_ne_filter(self, document_store: ChromaDocumentStore, filterable_docs: L
result, [doc for doc in filterable_docs if doc.meta.get("page", "100") != "100"]
)

@pytest.mark.unit
def test_delete_empty(self, document_store: ChromaDocumentStore):
"""
Deleting a non-existing document should not raise with Chroma
"""
document_store.delete_documents(["test"])

@pytest.mark.unit
def test_delete_not_empty_nonexisting(self, document_store: ChromaDocumentStore):
"""
Deleting a non-existing document should not raise with Chroma
Expand Down Expand Up @@ -131,144 +128,117 @@ def test_same_collection_name_reinitialization(self):
ChromaDocumentStore("test_name")

@pytest.mark.skip(reason="Filter on array contents is not supported.")
@pytest.mark.unit
def test_filter_document_array(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter on dataframe contents is not supported.")
@pytest.mark.unit
def test_filter_document_dataframe(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter on table contents is not supported.")
@pytest.mark.unit
def test_eq_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter on embedding value is not supported.")
@pytest.mark.unit
def test_eq_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="$in operator is not supported.")
@pytest.mark.unit
def test_in_filter_explicit(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="$in operator is not supported. Filter on table contents is not supported.")
@pytest.mark.unit
def test_in_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="$in operator is not supported.")
@pytest.mark.unit
def test_in_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter on table contents is not supported.")
@pytest.mark.unit
def test_ne_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter on embedding value is not supported.")
@pytest.mark.unit
def test_ne_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="$nin operator is not supported. Filter on table contents is not supported.")
@pytest.mark.unit
def test_nin_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="$nin operator is not supported. Filter on embedding value is not supported.")
@pytest.mark.unit
def test_nin_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="$nin operator is not supported.")
@pytest.mark.unit
def test_nin_filter(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
@pytest.mark.unit
def test_filter_simple_implicit_and_with_multi_key_dict(
self, document_store: ChromaDocumentStore, filterable_docs: List[Document]
):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
@pytest.mark.unit
def test_filter_simple_explicit_and_with_multikey_dict(
self, document_store: ChromaDocumentStore, filterable_docs: List[Document]
):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
@pytest.mark.unit
def test_filter_simple_explicit_and_with_list(
self, document_store: ChromaDocumentStore, filterable_docs: List[Document]
):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
@pytest.mark.unit
def test_filter_simple_implicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
@pytest.mark.unit
def test_filter_nested_explicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
@pytest.mark.unit
def test_filter_nested_implicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
@pytest.mark.unit
def test_filter_simple_or(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
@pytest.mark.unit
def test_filter_nested_or(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter on table contents is not supported.")
@pytest.mark.unit
def test_filter_nested_and_or_explicit(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
@pytest.mark.unit
def test_filter_nested_and_or_implicit(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
@pytest.mark.unit
def test_filter_nested_or_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Filter syntax not supported.")
@pytest.mark.unit
def test_filter_nested_multiple_identical_operators_same_level(
self, document_store: ChromaDocumentStore, filterable_docs: List[Document]
):
pass

@pytest.mark.skip(reason="Duplicate policy not supported.")
@pytest.mark.unit
def test_write_duplicate_fail(self, document_store: ChromaDocumentStore):
pass

@pytest.mark.skip(reason="Duplicate policy not supported.")
@pytest.mark.unit
def test_write_duplicate_skip(self, document_store: ChromaDocumentStore):
pass

@pytest.mark.skip(reason="Duplicate policy not supported.")
@pytest.mark.unit
def test_write_duplicate_overwrite(self, document_store: ChromaDocumentStore):
pass
12 changes: 0 additions & 12 deletions integrations/cohere/tests/test_cohere_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def chat_messages():


class TestCohereChatGenerator:
@pytest.mark.unit
def test_init_default(self, monkeypatch):
monkeypatch.setenv("COHERE_API_KEY", "test-api-key")

Expand All @@ -64,14 +63,12 @@ def test_init_default(self, monkeypatch):
assert component.api_base_url == cohere.COHERE_API_URL
assert not component.generation_kwargs

@pytest.mark.unit
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("COHERE_API_KEY", raising=False)
monkeypatch.delenv("CO_API_KEY", raising=False)
with pytest.raises(ValueError):
CohereChatGenerator()

@pytest.mark.unit
def test_init_with_parameters(self):
component = CohereChatGenerator(
api_key=Secret.from_token("test-api-key"),
Expand All @@ -86,7 +83,6 @@ def test_init_with_parameters(self):
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

@pytest.mark.unit
def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("COHERE_API_KEY", "test-api-key")
component = CohereChatGenerator()
Expand All @@ -102,7 +98,6 @@ def test_to_dict_default(self, monkeypatch):
},
}

@pytest.mark.unit
def test_to_dict_with_parameters(self, monkeypatch):
monkeypatch.setenv("COHERE_API_KEY", "test-api-key")
monkeypatch.setenv("CO_API_KEY", "fake-api-key")
Expand All @@ -125,7 +120,6 @@ def test_to_dict_with_parameters(self, monkeypatch):
},
}

@pytest.mark.unit
def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
monkeypatch.setenv("COHERE_API_KEY", "test-api-key")
component = CohereChatGenerator(
Expand All @@ -146,7 +140,6 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
},
}

@pytest.mark.unit
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("COHERE_API_KEY", "fake-api-key")
monkeypatch.setenv("CO_API_KEY", "fake-api-key")
Expand All @@ -166,7 +159,6 @@ def test_from_dict(self, monkeypatch):
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

@pytest.mark.unit
def test_from_dict_fail_wo_env_var(self, monkeypatch):
monkeypatch.delenv("COHERE_API_KEY", raising=False)
monkeypatch.delenv("CO_API_KEY", raising=False)
Expand All @@ -183,7 +175,6 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
with pytest.raises(ValueError):
CohereChatGenerator.from_dict(data)

@pytest.mark.unit
def test_run(self, chat_messages, mock_chat_response): # noqa: ARG002
component = CohereChatGenerator(api_key=Secret.from_token("test-api-key"))
response = component.run(chat_messages)
Expand All @@ -195,13 +186,11 @@ def test_run(self, chat_messages, mock_chat_response): # noqa: ARG002
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]

@pytest.mark.unit
def test_message_to_dict(self, chat_messages):
obj = CohereChatGenerator(api_key=Secret.from_token("test-api-key"))
dictionary = [obj._message_to_dict(message) for message in chat_messages]
assert dictionary == [{"user_name": "Chatbot", "text": "What's the capital of France"}]

@pytest.mark.unit
def test_run_with_params(self, chat_messages, mock_chat_response):
component = CohereChatGenerator(
api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5}
Expand All @@ -220,7 +209,6 @@ def test_run_with_params(self, chat_messages, mock_chat_response):
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]

@pytest.mark.unit
def test_run_streaming(self, chat_messages, mock_chat_response):
streaming_call_count = 0

Expand Down
1 change: 1 addition & 0 deletions integrations/deepeval/tests/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def test_evaluator_outputs(metric, inputs, expected_outputs, metric_params, monk
# OpenAI API. It is parameterized by the metric, the inputs to the evalutor
# and the metric parameters.
@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set")
@pytest.mark.integration
@pytest.mark.parametrize(
"metric, inputs, metric_params",
[
Expand Down
52 changes: 34 additions & 18 deletions integrations/mongodb_atlas/tests/test_retriever.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock, patch

import pytest
from haystack.dataclasses import Document
Expand All @@ -10,34 +10,48 @@
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore


@pytest.fixture
def document_store():
store = MongoDBAtlasDocumentStore(
database_name="haystack_integration_test",
collection_name="test_embeddings_collection",
vector_search_index="cosine_index",
)
return store
class TestRetriever:

@pytest.fixture
def mock_client(self):
with patch(
"haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient"
) as mock_mongo_client:
mock_connection = MagicMock()
mock_database = MagicMock()
mock_collection_names = MagicMock(return_value=["test_embeddings_collection"])
mock_database.list_collection_names = mock_collection_names
mock_connection.__getitem__.return_value = mock_database
mock_mongo_client.return_value = mock_connection
yield mock_mongo_client

class TestRetriever:
def test_init_default(self, document_store: MongoDBAtlasDocumentStore):
retriever = MongoDBAtlasEmbeddingRetriever(document_store=document_store)
assert retriever.document_store == document_store
def test_init_default(self):
mock_store = Mock(spec=MongoDBAtlasDocumentStore)
retriever = MongoDBAtlasEmbeddingRetriever(document_store=mock_store)
assert retriever.document_store == mock_store
assert retriever.filters == {}
assert retriever.top_k == 10

def test_init(self, document_store: MongoDBAtlasDocumentStore):
def test_init(self):
mock_store = Mock(spec=MongoDBAtlasDocumentStore)
retriever = MongoDBAtlasEmbeddingRetriever(
document_store=document_store,
document_store=mock_store,
filters={"field": "value"},
top_k=5,
)
assert retriever.document_store == document_store
assert retriever.document_store == mock_store
assert retriever.filters == {"field": "value"}
assert retriever.top_k == 5

def test_to_dict(self, document_store: MongoDBAtlasDocumentStore):
def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required
monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str")

document_store = MongoDBAtlasDocumentStore(
database_name="haystack_integration_test",
collection_name="test_embeddings_collection",
vector_search_index="cosine_index",
)

retriever = MongoDBAtlasEmbeddingRetriever(document_store=document_store, filters={"field": "value"}, top_k=5)
res = retriever.to_dict()
assert res == {
Expand All @@ -61,7 +75,9 @@ def test_to_dict(self, document_store: MongoDBAtlasDocumentStore):
},
}

def test_from_dict(self):
def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required
monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str")

data = {
"type": "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever", # noqa: E501
"init_parameters": {
Expand Down

0 comments on commit 27212f5

Please sign in to comment.