Skip to content

Commit

Permalink
fix: errors in convert_filters_to_qdrant (#870)
Browse files Browse the repository at this point in the history
* progress

* Fixed logic error

* Some tests are still failing

* Passed all tests

* Fixed errors in logic

* Fixed linting issues

* Minor adjustments

* Further improvements in code structure

* Final changes for review

* Updated

* Added more tests

* Add a test to check nested filters

* Minor changes

* Fix bugs and add docstrings

---------

Co-authored-by: Amna Mubashar <[email protected]>
  • Loading branch information
Amnah199 and Amna Mubashar authored Jul 10, 2024
1 parent 1ecfbfa commit c8f19a2
Show file tree
Hide file tree
Showing 2 changed files with 251 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,55 @@
from haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError
from qdrant_client.http import models

from .converters import convert_id

COMPARISON_OPERATORS = COMPARISON_OPERATORS.keys()
LOGICAL_OPERATORS = LOGICAL_OPERATORS.keys()


def convert_filters_to_qdrant(
filter_term: Optional[Union[List[dict], dict, models.Filter]] = None,
) -> Optional[models.Filter]:
"""Converts Haystack filters to the format used by Qdrant."""
filter_term: Optional[Union[List[dict], dict, models.Filter]] = None, is_parent_call: bool = True
) -> Optional[Union[models.Filter, List[models.Filter], List[models.Condition]]]:
"""Converts Haystack filters to the format used by Qdrant.
:param filter_term: the haystack filter to be converted to qdrant.
:param is_parent_call: indicates if this is the top-level call to the function. If True, the function returns
a single models.Filter object; if False, it may return a list of filters or conditions for further processing.
:returns: a single Qdrant Filter in the parent call or a list of such Filters in recursive calls.
:raises FilterError: If the invalid filter criteria is provided or if an unknown operator is encountered.
"""

if isinstance(filter_term, models.Filter):
return filter_term
if not filter_term:
return None

must_clauses, should_clauses, must_not_clauses = [], [], []
must_clauses: List[models.Filter] = []
should_clauses: List[models.Filter] = []
must_not_clauses: List[models.Filter] = []
# Indicates if there are multiple same LOGICAL OPERATORS on each level
# and prevents them from being combined
same_operator_flag = False
conditions, qdrant_filter, current_level_operators = (
[],
[],
[],
)

if isinstance(filter_term, dict):
filter_term = [filter_term]

# ======== IDENTIFY FILTER ITEMS ON EACH LEVEL ========

for item in filter_term:
operator = item.get("operator")

# Check for repeated similar operators on each level
same_operator_flag = operator in current_level_operators and operator in LOGICAL_OPERATORS
if not same_operator_flag:
current_level_operators.append(operator)

if operator is None:
msg = "Operator not found in filters"
raise FilterError(msg)
Expand All @@ -34,33 +61,130 @@ def convert_filters_to_qdrant(
msg = f"'conditions' not found for '{operator}'"
raise FilterError(msg)

if operator == "AND":
must_clauses.append(convert_filters_to_qdrant(item.get("conditions", [])))
elif operator == "OR":
should_clauses.append(convert_filters_to_qdrant(item.get("conditions", [])))
elif operator == "NOT":
must_not_clauses.append(convert_filters_to_qdrant(item.get("conditions", [])))
if operator in LOGICAL_OPERATORS:
# Recursively process nested conditions
current_filter = convert_filters_to_qdrant(item.get("conditions", []), is_parent_call=False) or []

# When same_operator_flag is set to True,
# ensure each clause is appended as an independent list to avoid merging distinct clauses.
if operator == "AND":
must_clauses = [must_clauses, current_filter] if same_operator_flag else must_clauses + current_filter
elif operator == "OR":
should_clauses = (
[should_clauses, current_filter] if same_operator_flag else should_clauses + current_filter
)
elif operator == "NOT":
must_not_clauses = (
[must_not_clauses, current_filter] if same_operator_flag else must_not_clauses + current_filter
)

elif operator in COMPARISON_OPERATORS:
field = item.get("field")
value = item.get("value")
if field is None or value is None:
msg = f"'field' or 'value' not found for '{operator}'"
raise FilterError(msg)

must_clauses.extend(_parse_comparison_operation(comparison_operation=operator, key=field, value=value))
parsed_conditions = _parse_comparison_operation(comparison_operation=operator, key=field, value=value)

# check if the parsed_conditions are models.Filter or models.Condition
for condition in parsed_conditions:
if isinstance(condition, models.Filter):
qdrant_filter.append(condition)
else:
conditions.append(condition)

else:
msg = f"Unknown operator {operator} used in filters"
raise FilterError(msg)

payload_filter = models.Filter(
must=must_clauses or None,
should=should_clauses or None,
must_not=must_not_clauses or None,
)
# ======== PROCESS FILTER ITEMS ON EACH LEVEL ========

# If same logical operators have separate clauses, create separate filters
if same_operator_flag:
qdrant_filter = build_filters_for_repeated_operators(
must_clauses, should_clauses, must_not_clauses, qdrant_filter
)

# else append a single Filter for existing clauses
elif must_clauses or should_clauses or must_not_clauses:
qdrant_filter.append(
models.Filter(
must=must_clauses or None,
should=should_clauses or None,
must_not=must_not_clauses or None,
)
)

# In case of parent call, a single Filter is returned
if is_parent_call:
# If qdrant_filter has just a single Filter in parent call,
# then it might be returned instead.
if len(qdrant_filter) == 1 and isinstance(qdrant_filter[0], models.Filter):
return qdrant_filter[0]
else:
must_clauses.extend(conditions)
return models.Filter(
must=must_clauses or None,
should=should_clauses or None,
must_not=must_not_clauses or None,
)

# Store conditions of each level in output of the loop
elif conditions:
qdrant_filter.extend(conditions)

return qdrant_filter


def build_filters_for_repeated_operators(
must_clauses,
should_clauses,
must_not_clauses,
qdrant_filter,
) -> List[models.Filter]:
"""
Flattens the nested lists of clauses by creating separate Filters for each clause of a logical operator.
:param must_clauses: a nested list of must clauses or an empty list.
:param should_clauses: a nested list of should clauses or an empty list.
:param must_not_clauses: a nested list of must_not clauses or an empty list.
:param qdrant_filter: a list where the generated Filter objects will be appended.
This list will be modified in-place.
filter_result = _squeeze_filter(payload_filter)
return filter_result
:returns: the modified `qdrant_filter` list with appended generated Filter objects.
"""

if any(isinstance(i, list) for i in must_clauses):
for i in must_clauses:
qdrant_filter.append(
models.Filter(
must=i or None,
should=should_clauses or None,
must_not=must_not_clauses or None,
)
)
if any(isinstance(i, list) for i in should_clauses):
for i in should_clauses:
qdrant_filter.append(
models.Filter(
must=must_clauses or None,
should=i or None,
must_not=must_not_clauses or None,
)
)
if any(isinstance(i, list) for i in must_not_clauses):
for i in must_clauses:
qdrant_filter.append(
models.Filter(
must=must_clauses or None,
should=should_clauses or None,
must_not=i or None,
)
)

return qdrant_filter


def _parse_comparison_operation(
Expand Down Expand Up @@ -92,7 +216,7 @@ def _parse_comparison_operation(

def _build_eq_condition(key: str, value: models.ValueVariants) -> models.Condition:
if isinstance(value, str) and " " in value:
models.FieldCondition(key=key, match=models.MatchText(text=value))
return models.FieldCondition(key=key, match=models.MatchText(text=value))
return models.FieldCondition(key=key, match=models.MatchValue(value=value))


Expand Down Expand Up @@ -184,52 +308,6 @@ def _build_gte_condition(key: str, value: Union[str, float, int]) -> models.Cond
raise FilterError(msg)


def _build_has_id_condition(id_values: List[models.ExtendedPointId]) -> models.HasIdCondition:
return models.HasIdCondition(
has_id=[
# Ids are converted into their internal representation
convert_id(item)
for item in id_values
]
)


def _squeeze_filter(payload_filter: models.Filter) -> models.Filter:
"""
Simplify given payload filter, if the nested structure might be unnested.
That happens if there is a single clause in that filter.
:param payload_filter:
:returns:
"""
filter_parts = {
"must": payload_filter.must,
"should": payload_filter.should,
"must_not": payload_filter.must_not,
}

total_clauses = sum(len(x) for x in filter_parts.values() if x is not None)
if total_clauses == 0 or total_clauses > 1:
return payload_filter

# Payload filter has just a single clause provided (either must, should
# or must_not). If that single clause is also of a models.Filter type,
# then it might be returned instead.
for part_name, filter_part in filter_parts.items():
if not filter_part:
continue

subfilter = filter_part[0]
if not isinstance(subfilter, models.Filter):
# The inner statement is a simple condition like models.FieldCondition
# so it cannot be simplified.
continue

if subfilter.must:
return models.Filter(**{part_name: subfilter.must})

return payload_filter


def is_datetime_string(value: str) -> bool:
try:
datetime.fromisoformat(value)
Expand Down
106 changes: 106 additions & 0 deletions integrations/qdrant/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,112 @@ def test_not_operator(self, document_store, filterable_docs):
[d for d in filterable_docs if (d.meta.get("number") != 100 and d.meta.get("name") != "name_0")],
)

def test_filter_criteria(self, document_store):
documents = [
Document(
content="This is test document 1.",
meta={"file_name": "file1", "classification": {"details": {"category1": 0.9, "category2": 0.3}}},
),
Document(
content="This is test document 2.",
meta={"file_name": "file2", "classification": {"details": {"category1": 0.1, "category2": 0.7}}},
),
Document(
content="This is test document 3.",
meta={"file_name": "file3", "classification": {"details": {"category1": 0.7, "category2": 0.9}}},
),
]

document_store.write_documents(documents)
filter_criteria = {
"operator": "AND",
"conditions": [
{"field": "meta.file_name", "operator": "in", "value": ["file1", "file2"]},
{
"operator": "OR",
"conditions": [
{"field": "meta.classification.details.category1", "operator": ">=", "value": 0.85},
{"field": "meta.classification.details.category2", "operator": ">=", "value": 0.85},
],
},
],
}
result = document_store.filter_documents(filter_criteria)
self.assert_documents_are_equal(
result,
[
d
for d in documents
if (d.meta.get("file_name") in ["file1", "file2"])
and (
(d.meta.get("classification").get("details").get("category1") >= 0.85)
or (d.meta.get("classification").get("details").get("category2") >= 0.85)
)
],
)

def test_complex_filter_criteria(self, document_store):
documents = [
Document(
content="This is test document 1.",
meta={
"file_name": "file1",
"classification": {"details": {"category1": 0.45, "category2": 0.5, "category3": 0.2}},
},
),
Document(
content="This is test document 2.",
meta={
"file_name": "file2",
"classification": {"details": {"category1": 0.95, "category2": 0.85, "category3": 0.4}},
},
),
Document(
content="This is test document 3.",
meta={
"file_name": "file3",
"classification": {"details": {"category1": 0.85, "category2": 0.7, "category3": 0.95}},
},
),
]

document_store.write_documents(documents)
filter_criteria = {
"operator": "AND",
"conditions": [
{"field": "meta.file_name", "operator": "in", "value": ["file1", "file2", "file3"]},
{
"operator": "AND",
"conditions": [
{"field": "meta.classification.details.category1", "operator": ">=", "value": 0.85},
{
"operator": "OR",
"conditions": [
{"field": "meta.classification.details.category2", "operator": ">=", "value": 0.8},
{"field": "meta.classification.details.category3", "operator": ">=", "value": 0.9},
],
},
],
},
],
}
result = document_store.filter_documents(filter_criteria)
self.assert_documents_are_equal(
result,
[
d
for d in documents
if (d.meta.get("file_name") in ["file1", "file2", "file3"])
and (
(d.meta.get("classification").get("details").get("category1") >= 0.85)
and (
(d.meta.get("classification").get("details").get("category2") >= 0.8)
or (d.meta.get("classification").get("details").get("category3") >= 0.9)
)
)
],
)

# ======== OVERRIDES FOR NONE VALUED FILTERS ========

def test_comparison_equal_with_none(self, document_store, filterable_docs):
Expand Down

0 comments on commit c8f19a2

Please sign in to comment.