diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py
index d639bbaf4..69fd7cbbd 100644
--- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py
+++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py
@@ -4,28 +4,55 @@
 from haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError
 from qdrant_client.http import models
 
-from .converters import convert_id
-
 COMPARISON_OPERATORS = COMPARISON_OPERATORS.keys()
 LOGICAL_OPERATORS = LOGICAL_OPERATORS.keys()
 
 
 def convert_filters_to_qdrant(
-    filter_term: Optional[Union[List[dict], dict, models.Filter]] = None,
-) -> Optional[models.Filter]:
-    """Converts Haystack filters to the format used by Qdrant."""
+    filter_term: Optional[Union[List[dict], dict, models.Filter]] = None, is_parent_call: bool = True
+) -> Optional[Union[models.Filter, List[models.Filter], List[models.Condition]]]:
+    """Converts Haystack filters to the format used by Qdrant.
+
+    :param filter_term: the haystack filter to be converted to qdrant.
+    :param is_parent_call: indicates if this is the top-level call to the function. If True, the function returns
+      a single models.Filter object; if False, it may return a list of filters or conditions for further processing.
+
+    :returns: a single Qdrant Filter in the parent call, or a list of Filters and/or Conditions in recursive calls.
+
+    :raises FilterError: If invalid filter criteria are provided or if an unknown operator is encountered.
+
+    """
+
     if isinstance(filter_term, models.Filter):
         return filter_term
     if not filter_term:
         return None
 
-    must_clauses, should_clauses, must_not_clauses = [], [], []
+    must_clauses: List[models.Filter] = []
+    should_clauses: List[models.Filter] = []
+    must_not_clauses: List[models.Filter] = []
+    # Indicates if there are multiple same LOGICAL OPERATORS on each level
+    # and prevents them from being combined
+    same_operator_flag = False
+    conditions, qdrant_filter, current_level_operators = (
+        [],
+        [],
+        [],
+    )
 
     if isinstance(filter_term, dict):
         filter_term = [filter_term]
 
+    # ======== IDENTIFY FILTER ITEMS ON EACH LEVEL ========
+
     for item in filter_term:
         operator = item.get("operator")
+
+        # Check for repeated similar operators on each level (NOTE: reassigned per item, so the last item wins)
+        same_operator_flag = operator in current_level_operators and operator in LOGICAL_OPERATORS
+        if not same_operator_flag:
+            current_level_operators.append(operator)
+
         if operator is None:
             msg = "Operator not found in filters"
             raise FilterError(msg)
@@ -34,12 +61,23 @@ def convert_filters_to_qdrant(
             msg = f"'conditions' not found for '{operator}'"
             raise FilterError(msg)
 
-        if operator == "AND":
-            must_clauses.append(convert_filters_to_qdrant(item.get("conditions", [])))
-        elif operator == "OR":
-            should_clauses.append(convert_filters_to_qdrant(item.get("conditions", [])))
-        elif operator == "NOT":
-            must_not_clauses.append(convert_filters_to_qdrant(item.get("conditions", [])))
+        if operator in LOGICAL_OPERATORS:
+            # Recursively process nested conditions
+            current_filter = convert_filters_to_qdrant(item.get("conditions", []), is_parent_call=False) or []
+
+            # When same_operator_flag is set to True,
+            # ensure each clause is appended as an independent list to avoid merging distinct clauses.
+            if operator == "AND":
+                must_clauses = [must_clauses, current_filter] if same_operator_flag else must_clauses + current_filter
+            elif operator == "OR":
+                should_clauses = (
+                    [should_clauses, current_filter] if same_operator_flag else should_clauses + current_filter
+                )
+            elif operator == "NOT":
+                must_not_clauses = (
+                    [must_not_clauses, current_filter] if same_operator_flag else must_not_clauses + current_filter
+                )
+
         elif operator in COMPARISON_OPERATORS:
             field = item.get("field")
             value = item.get("value")
@@ -47,20 +85,106 @@ def convert_filters_to_qdrant(
                 msg = f"'field' or 'value' not found for '{operator}'"
                 raise FilterError(msg)
 
-            must_clauses.extend(_parse_comparison_operation(comparison_operation=operator, key=field, value=value))
+            parsed_conditions = _parse_comparison_operation(comparison_operation=operator, key=field, value=value)
+
+            # check if the parsed_conditions are models.Filter or models.Condition
+            for condition in parsed_conditions:
+                if isinstance(condition, models.Filter):
+                    qdrant_filter.append(condition)
+                else:
+                    conditions.append(condition)
+
         else:
             msg = f"Unknown operator {operator} used in filters"
             raise FilterError(msg)
 
-    payload_filter = models.Filter(
-        must=must_clauses or None,
-        should=should_clauses or None,
-        must_not=must_not_clauses or None,
-    )
+    # ======== PROCESS FILTER ITEMS ON EACH LEVEL ========
+
+    # If same logical operators have separate clauses, create separate filters
+    if same_operator_flag:
+        qdrant_filter = build_filters_for_repeated_operators(
+            must_clauses, should_clauses, must_not_clauses, qdrant_filter
+        )
+
+    # else append a single Filter for existing clauses
+    elif must_clauses or should_clauses or must_not_clauses:
+        qdrant_filter.append(
+            models.Filter(
+                must=must_clauses or None,
+                should=should_clauses or None,
+                must_not=must_not_clauses or None,
+            )
+        )
+
+    # In case of parent call, a single Filter is returned
+    if is_parent_call:
+        # If qdrant_filter has just a single Filter in parent call,
+        # then it might be returned instead.
+        if len(qdrant_filter) == 1 and isinstance(qdrant_filter[0], models.Filter):
+            return qdrant_filter[0]
+        else:
+            must_clauses.extend(conditions)
+            return models.Filter(
+                must=must_clauses or None,
+                should=should_clauses or None,
+                must_not=must_not_clauses or None,
+            )
+
+    # Store conditions of each level in output of the loop
+    elif conditions:
+        qdrant_filter.extend(conditions)
+
+    return qdrant_filter
+
+
+def build_filters_for_repeated_operators(
+    must_clauses,
+    should_clauses,
+    must_not_clauses,
+    qdrant_filter,
+) -> List[models.Filter]:
+    """
+    Flattens the nested lists of clauses by creating separate Filters for each clause of a logical operator.
+
+    :param must_clauses: a nested list of must clauses or an empty list.
+    :param should_clauses: a nested list of should clauses or an empty list.
+    :param must_not_clauses: a nested list of must_not clauses or an empty list.
+    :param qdrant_filter: a list where the generated Filter objects will be appended.
+        This list will be modified in-place.
 
-    filter_result = _squeeze_filter(payload_filter)
 
-    return filter_result
+    :returns: the modified `qdrant_filter` list with appended generated Filter objects.
+    """
+
+    if any(isinstance(i, list) for i in must_clauses):
+        for i in must_clauses:
+            qdrant_filter.append(
+                models.Filter(
+                    must=i or None,
+                    should=should_clauses or None,
+                    must_not=must_not_clauses or None,
+                )
+            )
+    if any(isinstance(i, list) for i in should_clauses):
+        for i in should_clauses:
+            qdrant_filter.append(
+                models.Filter(
+                    must=must_clauses or None,
+                    should=i or None,
+                    must_not=must_not_clauses or None,
+                )
+            )
+    if any(isinstance(i, list) for i in must_not_clauses):
+        for i in must_not_clauses:
+            qdrant_filter.append(
+                models.Filter(
+                    must=must_clauses or None,
+                    should=should_clauses or None,
+                    must_not=i or None,
+                )
+            )
+
+    return qdrant_filter
 
 
 def _parse_comparison_operation(
@@ -92,7 +216,7 @@ def _parse_comparison_operation(
 
 def _build_eq_condition(key: str, value: models.ValueVariants) -> models.Condition:
     if isinstance(value, str) and " " in value:
-        models.FieldCondition(key=key, match=models.MatchText(text=value))
+        return models.FieldCondition(key=key, match=models.MatchText(text=value))
     return models.FieldCondition(key=key, match=models.MatchValue(value=value))
 
 
@@ -184,52 +308,6 @@ def _build_gte_condition(key: str, value: Union[str, float, int]) -> models.Cond
     raise FilterError(msg)
 
 
-def _build_has_id_condition(id_values: List[models.ExtendedPointId]) -> models.HasIdCondition:
-    return models.HasIdCondition(
-        has_id=[
-            # Ids are converted into their internal representation
-            convert_id(item)
-            for item in id_values
-        ]
-    )
-
-
-def _squeeze_filter(payload_filter: models.Filter) -> models.Filter:
-    """
-    Simplify given payload filter, if the nested structure might be unnested.
-    That happens if there is a single clause in that filter.
-    :param payload_filter:
-    :returns:
-    """
-    filter_parts = {
-        "must": payload_filter.must,
-        "should": payload_filter.should,
-        "must_not": payload_filter.must_not,
-    }
-
-    total_clauses = sum(len(x) for x in filter_parts.values() if x is not None)
-    if total_clauses == 0 or total_clauses > 1:
-        return payload_filter
-
-    # Payload filter has just a single clause provided (either must, should
-    # or must_not). If that single clause is also of a models.Filter type,
-    # then it might be returned instead.
-    for part_name, filter_part in filter_parts.items():
-        if not filter_part:
-            continue
-
-        subfilter = filter_part[0]
-        if not isinstance(subfilter, models.Filter):
-            # The inner statement is a simple condition like models.FieldCondition
-            # so it cannot be simplified.
-            continue
-
-        if subfilter.must:
-            return models.Filter(**{part_name: subfilter.must})
-
-    return payload_filter
-
-
 def is_datetime_string(value: str) -> bool:
     try:
         datetime.fromisoformat(value)
diff --git a/integrations/qdrant/tests/test_filters.py b/integrations/qdrant/tests/test_filters.py
index 016ff57b6..fd070bda9 100644
--- a/integrations/qdrant/tests/test_filters.py
+++ b/integrations/qdrant/tests/test_filters.py
@@ -61,6 +61,112 @@ def test_not_operator(self, document_store, filterable_docs):
             [d for d in filterable_docs if (d.meta.get("number") != 100 and d.meta.get("name") != "name_0")],
         )
 
+    def test_filter_criteria(self, document_store):
+        documents = [
+            Document(
+                content="This is test document 1.",
+                meta={"file_name": "file1", "classification": {"details": {"category1": 0.9, "category2": 0.3}}},
+            ),
+            Document(
+                content="This is test document 2.",
+                meta={"file_name": "file2", "classification": {"details": {"category1": 0.1, "category2": 0.7}}},
+            ),
+            Document(
+                content="This is test document 3.",
+                meta={"file_name": "file3", "classification": {"details": {"category1": 0.7, "category2": 0.9}}},
+            ),
+        ]
+
+        document_store.write_documents(documents)
+        filter_criteria = {
+            "operator": "AND",
+            "conditions": [
+                {"field": "meta.file_name", "operator": "in", "value": ["file1", "file2"]},
+                {
+                    "operator": "OR",
+                    "conditions": [
+                        {"field": "meta.classification.details.category1", "operator": ">=", "value": 0.85},
+                        {"field": "meta.classification.details.category2", "operator": ">=", "value": 0.85},
+                    ],
+                },
+            ],
+        }
+        result = document_store.filter_documents(filter_criteria)
+        self.assert_documents_are_equal(
+            result,
+            [
+                d
+                for d in documents
+                if (d.meta.get("file_name") in ["file1", "file2"])
+                and (
+                    (d.meta.get("classification").get("details").get("category1") >= 0.85)
+                    or (d.meta.get("classification").get("details").get("category2") >= 0.85)
+                )
+            ],
+        )
+
+    def test_complex_filter_criteria(self, document_store):
+        documents = [
+            Document(
+                content="This is test document 1.",
+                meta={
+                    "file_name": "file1",
+                    "classification": {"details": {"category1": 0.45, "category2": 0.5, "category3": 0.2}},
+                },
+            ),
+            Document(
+                content="This is test document 2.",
+                meta={
+                    "file_name": "file2",
+                    "classification": {"details": {"category1": 0.95, "category2": 0.85, "category3": 0.4}},
+                },
+            ),
+            Document(
+                content="This is test document 3.",
+                meta={
+                    "file_name": "file3",
+                    "classification": {"details": {"category1": 0.85, "category2": 0.7, "category3": 0.95}},
+                },
+            ),
+        ]
+
+        document_store.write_documents(documents)
+        filter_criteria = {
+            "operator": "AND",
+            "conditions": [
+                {"field": "meta.file_name", "operator": "in", "value": ["file1", "file2", "file3"]},
+                {
+                    "operator": "AND",
+                    "conditions": [
+                        {"field": "meta.classification.details.category1", "operator": ">=", "value": 0.85},
+                        {
+                            "operator": "OR",
+                            "conditions": [
+                                {"field": "meta.classification.details.category2", "operator": ">=", "value": 0.8},
+                                {"field": "meta.classification.details.category3", "operator": ">=", "value": 0.9},
+                            ],
+                        },
+                    ],
+                },
+            ],
+        }
+        result = document_store.filter_documents(filter_criteria)
+        self.assert_documents_are_equal(
+            result,
+            [
+                d
+                for d in documents
+                if (d.meta.get("file_name") in ["file1", "file2", "file3"])
+                and (
+                    (d.meta.get("classification").get("details").get("category1") >= 0.85)
+                    and (
+                        (d.meta.get("classification").get("details").get("category2") >= 0.8)
+                        or (d.meta.get("classification").get("details").get("category3") >= 0.9)
+                    )
+                )
+            ],
+        )
+
     # ======== OVERRIDES FOR NONE VALUED FILTERS ========
 
     def test_comparison_equal_with_none(self, document_store, filterable_docs):