From 4bfc290b67a3bf92b7f5d86a195506723aa9664d Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 5 Mar 2024 13:19:50 +0100 Subject: [PATCH 1/8] wip --- .../mongodb_atlas/document_store.py | 10 +- .../document_stores/mongodb_atlas/filters.py | 193 +++++++++++++++++- .../mongodb_atlas/tests/test_filters.py | 48 +++++ 3 files changed, 241 insertions(+), 10 deletions(-) create mode 100644 integrations/mongodb_atlas/tests/test_filters.py diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index e2f2534f5..a56081488 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -10,7 +10,9 @@ from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace -from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo +from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters +from haystack.utils.filters import convert + from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne # type: ignore from pymongo.driver_info import DriverInfo # type: ignore from pymongo.errors import BulkWriteError # type: ignore @@ -100,8 +102,10 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :param filters: The filters to apply. It returns only the documents that match the filters. :return: A list of Documents that match the given filters. """ - mongo_filters = haystack_filters_to_mongo(filters) - documents = list(self.collection.find(mongo_filters)) + if filters and "operator" not in filters and "conditions" not in filters: + filters = convert(filters) + filters = _normalize_filters(filters) if filters else None + documents = list(self.collection.find(filters)) for doc in documents: doc.pop("_id", None) # MongoDB's internal id doesn't belong into a Haystack document, so we remove it. return [Document.from_dict(doc) for doc in documents] diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py index f03ca88c0..68141bace 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py @@ -1,9 +1,188 @@ -from typing import Any, Dict, Optional +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict +from haystack.errors import FilterError +from pandas import DataFrame + + +def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]: + """ + Converts Haystack filters in Pinecone compatible filters. 
+ Reference: https://docs.pinecone.io/docs/metadata-filtering + """ + if not isinstance(filters, dict): + msg = "Filters must be a dictionary" + raise FilterError(msg) + + if "field" in filters: + return _parse_comparison_condition(filters) + return _parse_logical_condition(filters) + + +def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "conditions" not in condition: + msg = f"'conditions' key missing in {condition}" + raise FilterError(msg) + + operator = condition["operator"] + conditions = [_parse_comparison_condition(c) for c in condition["conditions"]] + + if operator in LOGICAL_OPERATORS: + return {LOGICAL_OPERATORS[operator]: conditions} + + msg = f"Unknown logical operator '{operator}'" + raise FilterError(msg) + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: + if "field" not in condition: + # 'field' key is only found in comparison dictionaries. + # We assume this is a logic dictionary since it's not present. + return _parse_logical_condition(condition) + + field: str = condition["field"] + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "value" not in condition: + msg = f"'value' key missing in {condition}" + raise FilterError(msg) + operator: str = condition["operator"] + if field.startswith("meta."): + # Remove the "meta." prefix if present. + # Documents are flattened when using the PineconeDocumentStore + # so we don't need to specify the "meta." prefix. + # Instead of raising an error we handle it gracefully. + field = field[5:] + + value: Any = condition["value"] + if isinstance(value, DataFrame): + value = value.to_json() + + return COMPARISON_OPERATORS[operator](field, value) + + +def _equal(field: str, value: Any) -> Dict[str, Any]: + supported_types = (str, int, float, bool, type(None)) + if not isinstance(value, supported_types): + msg = ( + f"Unsupported type for 'equal' comparison: {type(value)}. " + f"Types supported by Pinecone are: {supported_types}" + ) + raise FilterError(msg) + + return {field: {"$eq": value}} + + +def _not_equal(field: str, value: Any) -> Dict[str, Any]: + supported_types = (str, int, float, bool, type(None)) + if not isinstance(value, supported_types): + msg = ( + f"Unsupported type for 'not equal' comparison: {type(value)}. " + f"Types supported by Pinecone are: {supported_types}" + ) + raise FilterError(msg) + + return {field: {"$ne": value}} + + +def _greater_than(field: str, value: Any) -> Dict[str, Any]: + supported_types = (int, float, type(None)) + if not isinstance(value, supported_types): + msg = ( + f"Unsupported type for 'greater than' comparison: {type(value)}. " + f"Types supported by Pinecone are: {supported_types}" + ) + raise FilterError(msg) + + return {field: {"$gt": value}} + + +def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]: + supported_types = (int, float) + if not isinstance(value, supported_types): + msg = ( + f"Unsupported type for 'greater than equal' comparison: {type(value)}. " + f"Types supported by Pinecone are: {supported_types}" + ) + raise FilterError(msg) + + return {field: {"$gte": value}} + + +def _less_than(field: str, value: Any) -> Dict[str, Any]: + supported_types = (int, float) + if not isinstance(value, supported_types): + msg = ( + f"Unsupported type for 'less than' comparison: {type(value)}. 
" + f"Types supported by Pinecone are: {supported_types}" + ) + raise FilterError(msg) + + return {field: {"$lt": value}} + + +def _less_than_equal(field: str, value: Any) -> Dict[str, Any]: + supported_types = (int, float) + if not isinstance(value, supported_types): + msg = ( + f"Unsupported type for 'less than equal' comparison: {type(value)}. " + f"Types supported by Pinecone are: {supported_types}" + ) + raise FilterError(msg) + + return {field: {"$lte": value}} + + +def _not_in(field: str, value: Any) -> Dict[str, Any]: + if not isinstance(value, list): + msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone" + raise FilterError(msg) + + supported_types = (int, float, str) + for v in value: + if not isinstance(v, supported_types): + msg = ( + f"Unsupported type for 'not in' comparison: {type(v)}. " + f"Types supported by Pinecone are: {supported_types}" + ) + raise FilterError(msg) + + return {field: {"$nin": value}} + + +def _in(field: str, value: Any) -> Dict[str, Any]: + if not isinstance(value, list): + msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone" + raise FilterError(msg) + + supported_types = (int, float, str) + for v in value: + if not isinstance(v, supported_types): + msg = ( + f"Unsupported type for 'in' comparison: {type(v)}. " + f"Types supported by Pinecone are: {supported_types}" + ) + raise FilterError(msg) + + return {field: {"$in": value}} + + +COMPARISON_OPERATORS = { + "==": _equal, + "!=": _not_equal, + ">": _greater_than, + ">=": _greater_than_equal, + "<": _less_than, + "<=": _less_than_equal, + "in": _in, + "not in": _not_in, +} + +LOGICAL_OPERATORS = {"AND": "$and", "OR": "$or"} -def haystack_filters_to_mongo(filters: Optional[Dict[str, Any]]): - # TODO - if filters: - msg = "Filtering not yet implemented for MongoDBAtlasDocumentStore" - raise ValueError(msg) - return {} diff --git a/integrations/mongodb_atlas/tests/test_filters.py b/integrations/mongodb_atlas/tests/test_filters.py new file mode 100644 index 000000000..56dd0ae7c --- /dev/null +++ b/integrations/mongodb_atlas/tests/test_filters.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from uuid import uuid4 + +import pytest +from haystack.dataclasses.document import ByteStream, Document +from haystack.document_stores.errors import DuplicateDocumentError +from haystack.document_stores.types import DuplicatePolicy +from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest, FilterDocumentsTest +from haystack.utils import Secret +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore +from pandas import DataFrame +from pymongo import MongoClient # type: ignore +from pymongo.driver_info import DriverInfo # type: ignore +import pandas as pd + + +@pytest.fixture +def document_store(): + database_name = "haystack_integration_test" + collection_name = "test_collection_" + str(uuid4()) + + connection: MongoClient = MongoClient( + os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") + ) + database = connection[database_name] + if collection_name in database.list_collection_names(): + database[collection_name].drop() + database.create_collection(collection_name) + database[collection_name].create_index("id", unique=True) + + store = MongoDBAtlasDocumentStore( + database_name=database_name, + collection_name=collection_name, + 
vector_search_index="cosine_index", + ) + yield store + database[collection_name].drop() + + +@pytest.mark.skipif( + "MONGO_CONNECTION_STRING" not in os.environ, + reason="No MongoDB Atlas connection string provided", +) +class TestFilters(FilterDocumentsTest): + pass \ No newline at end of file From 7c78b79ac0f8ae728d3edc20015f2e14a999656b Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 6 Mar 2024 08:50:52 +0100 Subject: [PATCH 2/8] progress --- .../mongodb_atlas/document_store.py | 15 +- .../document_stores/mongodb_atlas/filters.py | 159 +++++++++--------- .../tests/test_embedding_retrieval.py | 15 ++ .../mongodb_atlas/tests/test_filters.py | 34 +++- 4 files changed, 137 insertions(+), 86 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index a56081488..969a04574 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -11,7 +11,6 @@ from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters -from haystack.utils.filters import convert from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne # type: ignore from pymongo.driver_info import DriverInfo # type: ignore @@ -102,9 +101,8 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :param filters: The filters to apply. It returns only the documents that match the filters. :return: A list of Documents that match the given filters. """ - if filters and "operator" not in filters and "conditions" not in filters: - filters = convert(filters) filters = _normalize_filters(filters) if filters else None + print(filters) documents = list(self.collection.find(filters)) for doc in documents: doc.pop("_id", None) # MongoDB's internal id doesn't belong into a Haystack document, so we remove it. 
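For orientation, a minimal sketch (not part of the patch) of the conversion the new _normalize_filters helper is expected to perform, based on the filters.py changes later in this patch; the field names and values here are invented for illustration, and the expected output is inferred from the code rather than taken from the PR:

# Hypothetical example: a Haystack 2.x filter and the MongoDB query document that
# _normalize_filters should produce for it (usable directly with collection.find()).
from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters

haystack_filters = {
    "operator": "AND",
    "conditions": [
        {"field": "meta.type", "operator": "==", "value": "article"},
        {"field": "meta.rating", "operator": ">=", "value": 3},
    ],
}

mongo_filters = _normalize_filters(haystack_filters)
# Expected shape:
# {"$and": [{"meta.type": {"$eq": "article"}}, {"meta.rating": {"$gte": 3}}]}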
@@ -129,7 +127,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D if policy == DuplicatePolicy.NONE: policy = DuplicatePolicy.FAIL - mongo_documents = [doc.to_dict() for doc in documents] + mongo_documents = [doc.to_dict(flatten=False) for doc in documents] operations: List[Union[UpdateOne, InsertOne, ReplaceOne]] written_docs = len(documents) @@ -177,7 +175,8 @@ def embedding_retrieval( msg = "Query embedding must not be empty" raise ValueError(msg) - filters = haystack_filters_to_mongo(filters) + filters = _normalize_filters(filters) if filters else None + pipeline = [ { "$vectorSearch": { @@ -186,7 +185,7 @@ def embedding_retrieval( "queryVector": query_embedding, "numCandidates": 100, "limit": top_k, - # "filter": filters, + "filter": filters, } }, { @@ -207,10 +206,10 @@ def embedding_retrieval( msg = f"Retrieval of documents from MongoDB Atlas failed: {e}" raise DocumentStoreError(msg) from e - documents = [self.mongo_doc_to_haystack_doc(doc) for doc in documents] + documents = [self._mongo_doc_to_haystack_doc(doc) for doc in documents] return documents - def mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document: + def _mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document: """ Converts the dictionary coming out of MongoDB into a Haystack document diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py index 68141bace..4bcfb8558 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py @@ -1,20 +1,26 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH +# SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 from typing import Any, Dict from haystack.errors import FilterError from pandas import DataFrame +from datetime import datetime +from haystack.utils.filters import convert +UNSUPPORTED_TYPES_FOR_COMPARISON = (list, DataFrame) + def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]: """ - Converts Haystack filters in Pinecone compatible filters. - Reference: https://docs.pinecone.io/docs/metadata-filtering - """ + Converts Haystack filters to MongoDB filters. + """ if not isinstance(filters, dict): msg = "Filters must be a dictionary" raise FilterError(msg) + + if "operator" not in filters and "conditions" not in filters: + filters = convert(filters) if "field" in filters: return _parse_comparison_condition(filters) @@ -30,21 +36,28 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: raise FilterError(msg) operator = condition["operator"] - conditions = [_parse_comparison_condition(c) for c in condition["conditions"]] - - if operator in LOGICAL_OPERATORS: - return {LOGICAL_OPERATORS[operator]: conditions} - - msg = f"Unknown logical operator '{operator}'" - raise FilterError(msg) - + if operator not in ["AND", "OR", "NOT"]: + msg = f"Unknown logical operator '{operator}'. 
Valid operators are: 'AND', 'OR', 'NOT'" + raise FilterError(msg) + + # logical conditions can be nested, so we need to parse them recursively + conditions = [] + for c in condition["conditions"]: + if "field" in c: + conditions.append(_parse_comparison_condition(c)) + else: + conditions.append(_parse_logical_condition(c)) + + if operator == "AND": + return {"$and": conditions} + elif operator == "OR": + return {"$or": conditions} + elif operator == "NOT": + # MongoDB doesn't support our NOT operator (logical NAND) directly. + # we combine $nor and $and to achieve the same effect. + return {"$nor": [{"$and": conditions}]} def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: - if "field" not in condition: - # 'field' key is only found in comparison dictionaries. - # We assume this is a logic dictionary since it's not present. - return _parse_logical_condition(condition) - field: str = condition["field"] if "operator" not in condition: msg = f"'operator' key missing in {condition}" @@ -53,14 +66,8 @@ def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: msg = f"'value' key missing in {condition}" raise FilterError(msg) operator: str = condition["operator"] - if field.startswith("meta."): - # Remove the "meta." prefix if present. - # Documents are flattened when using the PineconeDocumentStore - # so we don't need to specify the "meta." prefix. - # Instead of raising an error we handle it gracefully. - field = field[5:] - value: Any = condition["value"] + if isinstance(value, DataFrame): value = value.to_json() @@ -68,73 +75,91 @@ def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: def _equal(field: str, value: Any) -> Dict[str, Any]: - supported_types = (str, int, float, bool, type(None)) - if not isinstance(value, supported_types): - msg = ( - f"Unsupported type for 'equal' comparison: {type(value)}. " - f"Types supported by Pinecone are: {supported_types}" - ) - raise FilterError(msg) - return {field: {"$eq": value}} - def _not_equal(field: str, value: Any) -> Dict[str, Any]: - supported_types = (str, int, float, bool, type(None)) - if not isinstance(value, supported_types): - msg = ( - f"Unsupported type for 'not equal' comparison: {type(value)}. " - f"Types supported by Pinecone are: {supported_types}" - ) - raise FilterError(msg) - return {field: {"$ne": value}} - def _greater_than(field: str, value: Any) -> Dict[str, Any]: - supported_types = (int, float, type(None)) - if not isinstance(value, supported_types): + if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): msg = ( - f"Unsupported type for 'greater than' comparison: {type(value)}. " - f"Types supported by Pinecone are: {supported_types}" + f"Unsupported type for '>' comparison: {type(value)}. " ) raise FilterError(msg) + elif isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc return {field: {"$gt": value}} def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]: - supported_types = (int, float) - if not isinstance(value, supported_types): + if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): msg = ( - f"Unsupported type for 'greater than equal' comparison: {type(value)}. " - f"Types supported by Pinecone are: {supported_types}" + f"Unsupported type for '>=' comparison: {type(value)}. 
" ) raise FilterError(msg) + elif isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + elif value is None: + # we want {field: {"$gte": null}} to return an empty result + # $gte with null values in MongoDB returns a non-empty result, while $gt aligns with our expectations + return {field: {"$gt": value}} return {field: {"$gte": value}} def _less_than(field: str, value: Any) -> Dict[str, Any]: - supported_types = (int, float) - if not isinstance(value, supported_types): + if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): msg = ( - f"Unsupported type for 'less than' comparison: {type(value)}. " - f"Types supported by Pinecone are: {supported_types}" + f"Unsupported type for '<' comparison: {type(value)}. " ) raise FilterError(msg) + elif isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc return {field: {"$lt": value}} def _less_than_equal(field: str, value: Any) -> Dict[str, Any]: - supported_types = (int, float) - if not isinstance(value, supported_types): + if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): msg = ( f"Unsupported type for 'less than equal' comparison: {type(value)}. " - f"Types supported by Pinecone are: {supported_types}" ) raise FilterError(msg) + elif isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + elif value is None: + # we want {field: {"$lte": null}} to return an empty result + # $lte with null values in MongoDB returns a non-empty result, while $lt aligns with our expectations + return {field: {"$lt": value}} return {field: {"$lte": value}} @@ -144,15 +169,6 @@ def _not_in(field: str, value: Any) -> Dict[str, Any]: msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone" raise FilterError(msg) - supported_types = (int, float, str) - for v in value: - if not isinstance(v, supported_types): - msg = ( - f"Unsupported type for 'not in' comparison: {type(v)}. " - f"Types supported by Pinecone are: {supported_types}" - ) - raise FilterError(msg) - return {field: {"$nin": value}} @@ -161,15 +177,6 @@ def _in(field: str, value: Any) -> Dict[str, Any]: msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone" raise FilterError(msg) - supported_types = (int, float, str) - for v in value: - if not isinstance(v, supported_types): - msg = ( - f"Unsupported type for 'in' comparison: {type(v)}. 
" - f"Types supported by Pinecone are: {supported_types}" - ) - raise FilterError(msg) - return {field: {"$in": value}} @@ -184,5 +191,3 @@ def _in(field: str, value: Any) -> Dict[str, Any]: "not in": _not_in, } -LOGICAL_OPERATORS = {"AND": "$and", "OR": "$or"} - diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py index aa7790bc7..ec689bc0a 100644 --- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -7,6 +7,7 @@ import pytest from haystack.document_stores.errors import DocumentStoreError from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore +from haystack.dataclasses.document import Document @pytest.mark.skipif( @@ -72,3 +73,17 @@ def test_query_embedding_wrong_dimension(self): query_embedding = [0.1] * 4 with pytest.raises(DocumentStoreError): document_store.embedding_retrieval(query_embedding=query_embedding) + + def test_embedding_retrieval_with_filters(self): + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="fake", + vector_search_index="cosine_index", + ) + print(document_store.filter_documents()) + query_embedding = [0.1] * 768 + # document_store.write_documents([Document(content="Document B", embedding=query_embedding, meta={"category": "test"})]) + filters = {"field": "content", "operator": "!=", "value": "Document A"} + results = document_store.embedding_retrieval(query_embedding=query_embedding, top_k=100, filters=filters) + assert len(results) == 1 + assert results[0].content == "Document B" diff --git a/integrations/mongodb_atlas/tests/test_filters.py b/integrations/mongodb_atlas/tests/test_filters.py index 56dd0ae7c..6eb576005 100644 --- a/integrations/mongodb_atlas/tests/test_filters.py +++ b/integrations/mongodb_atlas/tests/test_filters.py @@ -45,4 +45,36 @@ def document_store(): reason="No MongoDB Atlas connection string provided", ) class TestFilters(FilterDocumentsTest): - pass \ No newline at end of file + def test_complex_filter(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + filters = { + "operator": "OR", + "conditions": [ + { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + }, + { + "operator": "AND", + "conditions": [ + {"field": "meta.page", "operator": "==", "value": "90"}, + {"field": "meta.chapter", "operator": "==", "value": "conclusion"}, + ], + }, + ], + } + + result = document_store.filter_documents(filters=filters) + + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if (d.meta.get("number") == 100 and d.meta.get("chapter") == "intro") + or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion") + ], + ) \ No newline at end of file From 20e14f6dc70b06cde543ec7e404772207ba4d343 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 6 Mar 2024 10:27:40 +0100 Subject: [PATCH 3/8] more tests --- .../mongodb_atlas/tests/test_embedding_retrieval.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py index ec689bc0a..d0ed282ac 100644 --- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py +++ 
b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -77,13 +77,12 @@ def test_query_embedding_wrong_dimension(self): def test_embedding_retrieval_with_filters(self): document_store = MongoDBAtlasDocumentStore( database_name="haystack_integration_test", - collection_name="fake", + collection_name="test_embeddings_collection", vector_search_index="cosine_index", ) - print(document_store.filter_documents()) query_embedding = [0.1] * 768 - # document_store.write_documents([Document(content="Document B", embedding=query_embedding, meta={"category": "test"})]) filters = {"field": "content", "operator": "!=", "value": "Document A"} - results = document_store.embedding_retrieval(query_embedding=query_embedding, top_k=100, filters=filters) - assert len(results) == 1 - assert results[0].content == "Document B" + results = document_store.embedding_retrieval(query_embedding=query_embedding, top_k=2, filters=filters) + assert len(results) == 2 + for doc in results: + assert doc.content != "Document A" From d87769ffec6d2c202b19cb7d556fd98f971fc1fe Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 6 Mar 2024 11:05:13 +0100 Subject: [PATCH 4/8] improvements --- .../mongodb_atlas/embedding_retriever.py | 4 +- .../mongodb_atlas/document_store.py | 7 +- .../document_stores/mongodb_atlas/errors.py | 4 - .../document_stores/mongodb_atlas/filters.py | 34 +++----- .../tests/test_document_store.py | 85 +++++++++++++------ .../tests/test_embedding_retrieval.py | 20 ++++- .../mongodb_atlas/tests/test_filters.py | 80 ----------------- 7 files changed, 101 insertions(+), 133 deletions(-) delete mode 100644 integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py delete mode 100644 integrations/mongodb_atlas/tests/test_filters.py diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index 432b86d4c..ffad97789 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -48,7 +48,9 @@ def __init__( Create the MongoDBAtlasDocumentStore component. :param document_store: An instance of MongoDBAtlasDocumentStore. - :param filters: Filters applied to the retrieved Documents. + :param filters: Filters applied to the retrieved Documents. Make sure that the fields used in the filters are + included in the configuration of the `vector_search_index`. The configuration must be done manually + in the Web UI of MongoDB Atlas. :param top_k: Maximum number of Documents to return. :raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`. 
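To make the docstring note above concrete, a hedged usage sketch (not part of the patch). It assumes the retriever class is exported as MongoDBAtlasEmbeddingRetriever, that MONGO_CONNECTION_STRING is set in the environment, that run() returns a dict with a "documents" key, and that the "cosine_index" vector search index declares "content" as a filter field, as described in the test docstring added later in this patch:

# Illustrative only: metadata filtering at retrieval time. The filtered field ("content")
# must be listed as a "filter" path in the vector search index, configured in the Atlas web UI.
from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasEmbeddingRetriever
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore

document_store = MongoDBAtlasDocumentStore(
    database_name="haystack_integration_test",
    collection_name="test_embeddings_collection",
    vector_search_index="cosine_index",
)
retriever = MongoDBAtlasEmbeddingRetriever(
    document_store=document_store,
    filters={"field": "content", "operator": "!=", "value": "Document A"},
    top_k=2,
)
result = retriever.run(query_embedding=[0.1] * 768)
# result["documents"] is expected to contain only documents whose content is not "Document A".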
diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index dd848aaea..976bbc8c4 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -11,7 +11,6 @@ from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters - from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne # type: ignore from pymongo.driver_info import DriverInfo # type: ignore from pymongo.errors import BulkWriteError # type: ignore @@ -146,7 +145,6 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :returns: A list of Documents that match the given filters. """ filters = _normalize_filters(filters) if filters else None - print(filters) documents = list(self.collection.find(filters)) for doc in documents: doc.pop("_id", None) # MongoDB's internal id doesn't belong into a Haystack document, so we remove it. @@ -252,6 +250,11 @@ def _embedding_retrieval( documents = list(self.collection.aggregate(pipeline)) except Exception as e: msg = f"Retrieval of documents from MongoDB Atlas failed: {e}" + if filters: + msg += ( + "\nMake sure that the fields used in the filters are included " + "in the `vector_search_index` configuration" + ) raise DocumentStoreError(msg) from e documents = [self._mongo_doc_to_haystack_doc(doc) for doc in documents] diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py deleted file mode 100644 index 132156bd0..000000000 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py +++ /dev/null @@ -1,4 +0,0 @@ -class MongoDBAtlasDocumentStoreError(Exception): - """Exception for issues that occur in a MongoDBAtlas document store""" - - pass diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py index 4bcfb8558..46bc32b81 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py @@ -1,24 +1,24 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from datetime import datetime from typing import Any, Dict from haystack.errors import FilterError -from pandas import DataFrame -from datetime import datetime from haystack.utils.filters import convert - +from pandas import DataFrame UNSUPPORTED_TYPES_FOR_COMPARISON = (list, DataFrame) + def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]: """ Converts Haystack filters to MongoDB filters. 
- """ + """ if not isinstance(filters, dict): msg = "Filters must be a dictionary" raise FilterError(msg) - + if "operator" not in filters and "conditions" not in filters: filters = convert(filters) @@ -39,7 +39,7 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: if operator not in ["AND", "OR", "NOT"]: msg = f"Unknown logical operator '{operator}'. Valid operators are: 'AND', 'OR', 'NOT'" raise FilterError(msg) - + # logical conditions can be nested, so we need to parse them recursively conditions = [] for c in condition["conditions"]: @@ -57,6 +57,7 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: # we combine $nor and $and to achieve the same effect. return {"$nor": [{"$and": conditions}]} + def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: field: str = condition["field"] if "operator" not in condition: @@ -77,14 +78,14 @@ def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: def _equal(field: str, value: Any) -> Dict[str, Any]: return {field: {"$eq": value}} + def _not_equal(field: str, value: Any) -> Dict[str, Any]: return {field: {"$ne": value}} + def _greater_than(field: str, value: Any) -> Dict[str, Any]: if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): - msg = ( - f"Unsupported type for '>' comparison: {type(value)}. " - ) + msg = f"Unsupported type for '>' comparison: {type(value)}. " raise FilterError(msg) elif isinstance(value, str): try: @@ -94,16 +95,14 @@ def _greater_than(field: str, value: Any) -> Dict[str, Any]: "Can't compare strings using operators '>', '>=', '<', '<='. " "Strings are only comparable if they are ISO formatted dates." ) - raise FilterError(msg) from exc + raise FilterError(msg) from exc return {field: {"$gt": value}} def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]: if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): - msg = ( - f"Unsupported type for '>=' comparison: {type(value)}. " - ) + msg = f"Unsupported type for '>=' comparison: {type(value)}. " raise FilterError(msg) elif isinstance(value, str): try: @@ -124,9 +123,7 @@ def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]: def _less_than(field: str, value: Any) -> Dict[str, Any]: if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): - msg = ( - f"Unsupported type for '<' comparison: {type(value)}. " - ) + msg = f"Unsupported type for '<' comparison: {type(value)}. " raise FilterError(msg) elif isinstance(value, str): try: @@ -143,9 +140,7 @@ def _less_than(field: str, value: Any) -> Dict[str, Any]: def _less_than_equal(field: str, value: Any) -> Dict[str, Any]: if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): - msg = ( - f"Unsupported type for 'less than equal' comparison: {type(value)}. " - ) + msg = f"Unsupported type for 'less than equal' comparison: {type(value)}. 
" raise FilterError(msg) elif isinstance(value, str): try: @@ -190,4 +185,3 @@ def _in(field: str, value: Any) -> Dict[str, Any]: "in": _in, "not in": _not_in, } - diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 39a4465c1..dcbd3cf84 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -8,7 +8,7 @@ from haystack.dataclasses.document import ByteStream, Document from haystack.document_stores.errors import DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy -from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest +from haystack.testing.document_store import DocumentStoreBaseTests from haystack.utils import Secret from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore from pandas import DataFrame @@ -16,34 +16,35 @@ from pymongo.driver_info import DriverInfo # type: ignore -@pytest.fixture -def document_store(): - database_name = "haystack_integration_test" - collection_name = "test_collection_" + str(uuid4()) - - connection: MongoClient = MongoClient( - os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") - ) - database = connection[database_name] - if collection_name in database.list_collection_names(): - database[collection_name].drop() - database.create_collection(collection_name) - database[collection_name].create_index("id", unique=True) - - store = MongoDBAtlasDocumentStore( - database_name=database_name, - collection_name=collection_name, - vector_search_index="cosine_index", - ) - yield store - database[collection_name].drop() - - @pytest.mark.skipif( "MONGO_CONNECTION_STRING" not in os.environ, reason="No MongoDB Atlas connection string provided", ) -class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): +@pytest.mark.integration +class TestDocumentStore(DocumentStoreBaseTests): + + @pytest.fixture + def document_store(self): + database_name = "haystack_integration_test" + collection_name = "test_collection_" + str(uuid4()) + + connection: MongoClient = MongoClient( + os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") + ) + database = connection[database_name] + if collection_name in database.list_collection_names(): + database[collection_name].drop() + database.create_collection(collection_name) + database[collection_name].create_index("id", unique=True) + + store = MongoDBAtlasDocumentStore( + database_name=database_name, + collection_name=collection_name, + vector_search_index="cosine_index", + ) + yield store + database[collection_name].drop() + def test_write_documents(self, document_store: MongoDBAtlasDocumentStore): docs = [Document(content="some text")] assert document_store.write_documents(docs) == 1 @@ -104,3 +105,37 @@ def test_from_dict(self): assert docstore.database_name == "haystack_integration_test" assert docstore.collection_name == "test_embeddings_collection" assert docstore.vector_search_index == "cosine_index" + + def test_complex_filter(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + filters = { + "operator": "OR", + "conditions": [ + { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + }, + { + "operator": 
"AND", + "conditions": [ + {"field": "meta.page", "operator": "==", "value": "90"}, + {"field": "meta.chapter", "operator": "==", "value": "conclusion"}, + ], + }, + ], + } + + result = document_store.filter_documents(filters=filters) + + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if (d.meta.get("number") == 100 and d.meta.get("chapter") == "intro") + or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion") + ], + ) diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py index 0fe693c25..a03c735e0 100644 --- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -7,13 +7,13 @@ import pytest from haystack.document_stores.errors import DocumentStoreError from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore -from haystack.dataclasses.document import Document @pytest.mark.skipif( "MONGO_CONNECTION_STRING" not in os.environ, reason="No MongoDB Atlas connection string provided", ) +@pytest.mark.integration class TestEmbeddingRetrieval: def test_embedding_retrieval_cosine_similarity(self): document_store = MongoDBAtlasDocumentStore( @@ -75,6 +75,24 @@ def test_query_embedding_wrong_dimension(self): document_store._embedding_retrieval(query_embedding=query_embedding) def test_embedding_retrieval_with_filters(self): + """ + Note: we can combine embedding retrieval with filters + becuse the `cosine_index` vector_search_index was created with the `content` field as the filter field. + { + "fields": [ + { + "type": "vector", + "path": "embedding", + "numDimensions": 768, + "similarity": "cosine" + }, + { + "type": "filter", + "path": "content" + } + ] + } + """ document_store = MongoDBAtlasDocumentStore( database_name="haystack_integration_test", collection_name="test_embeddings_collection", diff --git a/integrations/mongodb_atlas/tests/test_filters.py b/integrations/mongodb_atlas/tests/test_filters.py deleted file mode 100644 index 6eb576005..000000000 --- a/integrations/mongodb_atlas/tests/test_filters.py +++ /dev/null @@ -1,80 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -import os -from uuid import uuid4 - -import pytest -from haystack.dataclasses.document import ByteStream, Document -from haystack.document_stores.errors import DuplicateDocumentError -from haystack.document_stores.types import DuplicatePolicy -from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest, FilterDocumentsTest -from haystack.utils import Secret -from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore -from pandas import DataFrame -from pymongo import MongoClient # type: ignore -from pymongo.driver_info import DriverInfo # type: ignore -import pandas as pd - - -@pytest.fixture -def document_store(): - database_name = "haystack_integration_test" - collection_name = "test_collection_" + str(uuid4()) - - connection: MongoClient = MongoClient( - os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") - ) - database = connection[database_name] - if collection_name in database.list_collection_names(): - database[collection_name].drop() - database.create_collection(collection_name) - database[collection_name].create_index("id", unique=True) - - store = MongoDBAtlasDocumentStore( - 
database_name=database_name, - collection_name=collection_name, - vector_search_index="cosine_index", - ) - yield store - database[collection_name].drop() - - -@pytest.mark.skipif( - "MONGO_CONNECTION_STRING" not in os.environ, - reason="No MongoDB Atlas connection string provided", -) -class TestFilters(FilterDocumentsTest): - def test_complex_filter(self, document_store, filterable_docs): - document_store.write_documents(filterable_docs) - filters = { - "operator": "OR", - "conditions": [ - { - "operator": "AND", - "conditions": [ - {"field": "meta.number", "operator": "==", "value": 100}, - {"field": "meta.chapter", "operator": "==", "value": "intro"}, - ], - }, - { - "operator": "AND", - "conditions": [ - {"field": "meta.page", "operator": "==", "value": "90"}, - {"field": "meta.chapter", "operator": "==", "value": "conclusion"}, - ], - }, - ], - } - - result = document_store.filter_documents(filters=filters) - - self.assert_documents_are_equal( - result, - [ - d - for d in filterable_docs - if (d.meta.get("number") == 100 and d.meta.get("chapter") == "intro") - or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion") - ], - ) \ No newline at end of file From 3520635f8438473b9776d831d13b1d7262a41138 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 6 Mar 2024 11:09:38 +0100 Subject: [PATCH 5/8] ignore missing imports in pyproject --- integrations/mongodb_atlas/pyproject.toml | 3 +-- .../document_stores/mongodb_atlas/document_store.py | 6 +++--- integrations/mongodb_atlas/tests/test_document_store.py | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/integrations/mongodb_atlas/pyproject.toml b/integrations/mongodb_atlas/pyproject.toml index 0021884ad..b3eba42e2 100644 --- a/integrations/mongodb_atlas/pyproject.toml +++ b/integrations/mongodb_atlas/pyproject.toml @@ -175,8 +175,7 @@ exclude_lines = [ module = [ "haystack.*", "haystack_integrations.*", - "mongodb_atlas.*", - "psycopg.*", + "pymongo.*", "pytest.*" ] ignore_missing_imports = true diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 976bbc8c4..c9e8f1dae 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -11,9 +11,9 @@ from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters -from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne # type: ignore -from pymongo.driver_info import DriverInfo # type: ignore -from pymongo.errors import BulkWriteError # type: ignore +from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne +from pymongo.driver_info import DriverInfo +from pymongo.errors import BulkWriteError logger = logging.getLogger(__name__) diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index dcbd3cf84..89810ec8b 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -12,8 +12,8 @@ from haystack.utils import Secret from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore from pandas import 
DataFrame -from pymongo import MongoClient # type: ignore -from pymongo.driver_info import DriverInfo # type: ignore +from pymongo import MongoClient +from pymongo.driver_info import DriverInfo @pytest.mark.skipif( From 460b2dbcb8d8af2a37c6e5e6bcef4caa6eda9ff5 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 6 Mar 2024 11:16:47 +0100 Subject: [PATCH 6/8] fix mypy --- .../document_stores/mongodb_atlas/filters.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py index 46bc32b81..f8a9425a7 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py @@ -35,11 +35,6 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: msg = f"'conditions' key missing in {condition}" raise FilterError(msg) - operator = condition["operator"] - if operator not in ["AND", "OR", "NOT"]: - msg = f"Unknown logical operator '{operator}'. Valid operators are: 'AND', 'OR', 'NOT'" - raise FilterError(msg) - # logical conditions can be nested, so we need to parse them recursively conditions = [] for c in condition["conditions"]: @@ -48,6 +43,7 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: else: conditions.append(_parse_logical_condition(c)) + operator = condition["operator"] if operator == "AND": return {"$and": conditions} elif operator == "OR": @@ -57,6 +53,9 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: # we combine $nor and $and to achieve the same effect. return {"$nor": [{"$and": conditions}]} + msg = f"Unknown logical operator '{operator}'. 
Valid operators are: 'AND', 'OR', 'NOT'" + raise FilterError(msg) + def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: field: str = condition["field"] From 8c21042f0d82ca0fa4833eac54f16fbc9b4df249 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 6 Mar 2024 11:25:00 +0100 Subject: [PATCH 7/8] show coverage --- integrations/mongodb_atlas/pyproject.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/integrations/mongodb_atlas/pyproject.toml b/integrations/mongodb_atlas/pyproject.toml index b3eba42e2..6e6b55dfe 100644 --- a/integrations/mongodb_atlas/pyproject.toml +++ b/integrations/mongodb_atlas/pyproject.toml @@ -156,21 +156,21 @@ ban-relative-imports = "parents" "examples/**/*" = ["T201"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -tests = ["tests", "*/mongodb-atlas-haystack/tests"] - [tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true exclude_lines = [ "no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + [[tool.mypy.overrides]] module = [ "haystack.*", From e21c5df857d3a8ba21a89487fe11c26e3b3350f7 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 8 Mar 2024 13:48:30 +0100 Subject: [PATCH 8/8] rm code duplication --- .../document_stores/mongodb_atlas/filters.py | 56 ++++--------------- 1 file changed, 11 insertions(+), 45 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py index f8a9425a7..4583d6cd3 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py @@ -82,78 +82,44 @@ def _not_equal(field: str, value: Any) -> Dict[str, Any]: return {field: {"$ne": value}} -def _greater_than(field: str, value: Any) -> Dict[str, Any]: +def _validate_type_for_comparison(value: Any) -> None: + msg = f"Cant compare {type(value)} using operators '>', '>=', '<', '<='." if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): - msg = f"Unsupported type for '>' comparison: {type(value)}. " raise FilterError(msg) elif isinstance(value, str): try: datetime.fromisoformat(value) except (ValueError, TypeError) as exc: - msg = ( - "Can't compare strings using operators '>', '>=', '<', '<='. " - "Strings are only comparable if they are ISO formatted dates." - ) + msg += "\nStrings are only comparable if they are ISO formatted dates." raise FilterError(msg) from exc + +def _greater_than(field: str, value: Any) -> Dict[str, Any]: + _validate_type_for_comparison(value) return {field: {"$gt": value}} def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]: - if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): - msg = f"Unsupported type for '>=' comparison: {type(value)}. " - raise FilterError(msg) - elif isinstance(value, str): - try: - datetime.fromisoformat(value) - except (ValueError, TypeError) as exc: - msg = ( - "Can't compare strings using operators '>', '>=', '<', '<='. " - "Strings are only comparable if they are ISO formatted dates." 
- ) - raise FilterError(msg) from exc - elif value is None: + if value is None: # we want {field: {"$gte": null}} to return an empty result # $gte with null values in MongoDB returns a non-empty result, while $gt aligns with our expectations return {field: {"$gt": value}} + _validate_type_for_comparison(value) return {field: {"$gte": value}} def _less_than(field: str, value: Any) -> Dict[str, Any]: - if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): - msg = f"Unsupported type for '<' comparison: {type(value)}. " - raise FilterError(msg) - elif isinstance(value, str): - try: - datetime.fromisoformat(value) - except (ValueError, TypeError) as exc: - msg = ( - "Can't compare strings using operators '>', '>=', '<', '<='. " - "Strings are only comparable if they are ISO formatted dates." - ) - raise FilterError(msg) from exc - + _validate_type_for_comparison(value) return {field: {"$lt": value}} def _less_than_equal(field: str, value: Any) -> Dict[str, Any]: - if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): - msg = f"Unsupported type for 'less than equal' comparison: {type(value)}. " - raise FilterError(msg) - elif isinstance(value, str): - try: - datetime.fromisoformat(value) - except (ValueError, TypeError) as exc: - msg = ( - "Can't compare strings using operators '>', '>=', '<', '<='. " - "Strings are only comparable if they are ISO formatted dates." - ) - raise FilterError(msg) from exc - elif value is None: + if value is None: # we want {field: {"$lte": null}} to return an empty result # $lte with null values in MongoDB returns a non-empty result, while $lt aligns with our expectations return {field: {"$lt": value}} + _validate_type_for_comparison(value) return {field: {"$lte": value}}
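A quick illustration (not part of the patch) of the None special-casing preserved by this refactor; it assumes the package is importable locally and only exercises the pure filter-conversion code, so no Atlas connection is needed:

# ">= None" and "<= None" are deliberately rewritten to $gt / $lt so that MongoDB
# returns no documents when comparing against a missing value.
from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters

assert _normalize_filters({"field": "meta.number", "operator": ">=", "value": None}) == {
    "meta.number": {"$gt": None}
}
assert _normalize_filters({"field": "meta.number", "operator": "<=", "value": None}) == {
    "meta.number": {"$lt": None}
}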