Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MongoDB Atlas: filters #542

Merged
merged 9 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions integrations/mongodb_atlas/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -156,27 +156,26 @@ ban-relative-imports = "parents"
"examples/**/*" = ["T201"]

[tool.coverage.run]
source_pkgs = ["src", "tests"]
source = ["haystack_integrations"]
branch = true
parallel = true
parallel = false


[tool.coverage.paths]
tests = ["tests", "*/mongodb-atlas-haystack/tests"]

[tool.coverage.report]
omit = ["*/tests/*", "*/__init__.py"]
show_missing=true
exclude_lines = [
"no cov",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]


[[tool.mypy.overrides]]
module = [
"haystack.*",
"haystack_integrations.*",
"mongodb_atlas.*",
"psycopg.*",
"pymongo.*",
"pytest.*"
]
ignore_missing_imports = true
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def __init__(
Create the MongoDBAtlasDocumentStore component.

:param document_store: An instance of MongoDBAtlasDocumentStore.
:param filters: Filters applied to the retrieved Documents.
:param filters: Filters applied to the retrieved Documents. Make sure that the fields used in the filters are
included in the configuration of the `vector_search_index`. The configuration must be done manually
in the Web UI of MongoDB Atlas.
:param top_k: Maximum number of Documents to return.

:raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo
from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne # type: ignore
from pymongo.driver_info import DriverInfo # type: ignore
from pymongo.errors import BulkWriteError # type: ignore
from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters
from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne
from pymongo.driver_info import DriverInfo
from pymongo.errors import BulkWriteError

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -144,8 +144,8 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
:param filters: The filters to apply. It returns only the documents that match the filters.
:returns: A list of Documents that match the given filters.
"""
mongo_filters = haystack_filters_to_mongo(filters)
documents = list(self.collection.find(mongo_filters))
filters = _normalize_filters(filters) if filters else None
documents = list(self.collection.find(filters))
for doc in documents:
doc.pop("_id", None) # MongoDB's internal id doesn't belong into a Haystack document, so we remove it.
return [Document.from_dict(doc) for doc in documents]
Expand All @@ -170,7 +170,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
if policy == DuplicatePolicy.NONE:
policy = DuplicatePolicy.FAIL

mongo_documents = [doc.to_dict() for doc in documents]
mongo_documents = [doc.to_dict(flatten=False) for doc in documents]
operations: List[Union[UpdateOne, InsertOne, ReplaceOne]]
written_docs = len(documents)

Expand Down Expand Up @@ -221,7 +221,8 @@ def _embedding_retrieval(
msg = "Query embedding must not be empty"
raise ValueError(msg)

filters = haystack_filters_to_mongo(filters)
filters = _normalize_filters(filters) if filters else None

pipeline = [
{
"$vectorSearch": {
Expand All @@ -230,7 +231,7 @@ def _embedding_retrieval(
"queryVector": query_embedding,
"numCandidates": 100,
"limit": top_k,
# "filter": filters,
"filter": filters,
}
},
{
Expand All @@ -249,6 +250,11 @@ def _embedding_retrieval(
documents = list(self.collection.aggregate(pipeline))
except Exception as e:
msg = f"Retrieval of documents from MongoDB Atlas failed: {e}"
if filters:
msg += (
"\nMake sure that the fields used in the filters are included "
"in the `vector_search_index` configuration"
)
raise DocumentStoreError(msg) from e

documents = [self._mongo_doc_to_haystack_doc(doc) for doc in documents]
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,9 +1,186 @@
from typing import Any, Dict, Optional
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from typing import Any, Dict

from haystack.errors import FilterError
from haystack.utils.filters import convert
from pandas import DataFrame

def haystack_filters_to_mongo(filters: Optional[Dict[str, Any]]):
# TODO
if filters:
msg = "Filtering not yet implemented for MongoDBAtlasDocumentStore"
raise ValueError(msg)
return {}
UNSUPPORTED_TYPES_FOR_COMPARISON = (list, DataFrame)


def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
"""
Converts Haystack filters to MongoDB filters.
"""
if not isinstance(filters, dict):
msg = "Filters must be a dictionary"
raise FilterError(msg)

if "operator" not in filters and "conditions" not in filters:
filters = convert(filters)

if "field" in filters:
return _parse_comparison_condition(filters)
return _parse_logical_condition(filters)


def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
if "conditions" not in condition:
msg = f"'conditions' key missing in {condition}"
raise FilterError(msg)

# logical conditions can be nested, so we need to parse them recursively
conditions = []
for c in condition["conditions"]:
if "field" in c:
conditions.append(_parse_comparison_condition(c))
else:
conditions.append(_parse_logical_condition(c))

operator = condition["operator"]
if operator == "AND":
return {"$and": conditions}
elif operator == "OR":
return {"$or": conditions}
elif operator == "NOT":
# MongoDB doesn't support our NOT operator (logical NAND) directly.
# we combine $nor and $and to achieve the same effect.
return {"$nor": [{"$and": conditions}]}

msg = f"Unknown logical operator '{operator}'. Valid operators are: 'AND', 'OR', 'NOT'"
raise FilterError(msg)


def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
field: str = condition["field"]
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
if "value" not in condition:
msg = f"'value' key missing in {condition}"
raise FilterError(msg)
operator: str = condition["operator"]
value: Any = condition["value"]

if isinstance(value, DataFrame):
value = value.to_json()

return COMPARISON_OPERATORS[operator](field, value)


def _equal(field: str, value: Any) -> Dict[str, Any]:
return {field: {"$eq": value}}


def _not_equal(field: str, value: Any) -> Dict[str, Any]:
return {field: {"$ne": value}}


def _greater_than(field: str, value: Any) -> Dict[str, Any]:
if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON):
msg = f"Unsupported type for '>' comparison: {type(value)}. "
raise FilterError(msg)
elif isinstance(value, str):
try:
datetime.fromisoformat(value)
except (ValueError, TypeError) as exc:
msg = (
"Can't compare strings using operators '>', '>=', '<', '<='. "
"Strings are only comparable if they are ISO formatted dates."
)
raise FilterError(msg) from exc

return {field: {"$gt": value}}


def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:
if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON):
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
msg = f"Unsupported type for '>=' comparison: {type(value)}. "
raise FilterError(msg)
elif isinstance(value, str):
try:
datetime.fromisoformat(value)
except (ValueError, TypeError) as exc:
msg = (
"Can't compare strings using operators '>', '>=', '<', '<='. "
"Strings are only comparable if they are ISO formatted dates."
)
raise FilterError(msg) from exc
elif value is None:
# we want {field: {"$gte": null}} to return an empty result
# $gte with null values in MongoDB returns a non-empty result, while $gt aligns with our expectations
return {field: {"$gt": value}}

return {field: {"$gte": value}}


def _less_than(field: str, value: Any) -> Dict[str, Any]:
if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON):
msg = f"Unsupported type for '<' comparison: {type(value)}. "
raise FilterError(msg)
elif isinstance(value, str):
try:
datetime.fromisoformat(value)
except (ValueError, TypeError) as exc:
msg = (
"Can't compare strings using operators '>', '>=', '<', '<='. "
"Strings are only comparable if they are ISO formatted dates."
)
raise FilterError(msg) from exc

return {field: {"$lt": value}}


def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON):
msg = f"Unsupported type for 'less than equal' comparison: {type(value)}. "
raise FilterError(msg)
elif isinstance(value, str):
try:
datetime.fromisoformat(value)
except (ValueError, TypeError) as exc:
msg = (
"Can't compare strings using operators '>', '>=', '<', '<='. "
"Strings are only comparable if they are ISO formatted dates."
)
raise FilterError(msg) from exc
elif value is None:
# we want {field: {"$lte": null}} to return an empty result
# $lte with null values in MongoDB returns a non-empty result, while $lt aligns with our expectations
return {field: {"$lt": value}}

return {field: {"$lte": value}}


def _not_in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
raise FilterError(msg)

return {field: {"$nin": value}}


def _in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
raise FilterError(msg)

return {field: {"$in": value}}


COMPARISON_OPERATORS = {
"==": _equal,
"!=": _not_equal,
">": _greater_than,
">=": _greater_than_equal,
"<": _less_than,
"<=": _less_than_equal,
"in": _in,
"not in": _not_in,
}
Loading