From 435da6a5c59131c9e9090c2f263a87e993427224 Mon Sep 17 00:00:00 2001
From: anakin87
Date: Fri, 22 Dec 2023 16:05:49 +0100
Subject: [PATCH] filters

---
 .../src/pinecone_haystack/document_store.py   |  10 +-
 .../pinecone/src/pinecone_haystack/filters.py | 201 ++++++++++++++++++
 integrations/pinecone/tests/test_filters.py   |  81 +++++++
 3 files changed, 291 insertions(+), 1 deletion(-)
 create mode 100644 integrations/pinecone/src/pinecone_haystack/filters.py
 create mode 100644 integrations/pinecone/tests/test_filters.py

diff --git a/integrations/pinecone/src/pinecone_haystack/document_store.py b/integrations/pinecone/src/pinecone_haystack/document_store.py
index 576993de6..d6296e030 100644
--- a/integrations/pinecone/src/pinecone_haystack/document_store.py
+++ b/integrations/pinecone/src/pinecone_haystack/document_store.py
@@ -12,6 +12,9 @@
 from haystack import default_to_dict
 from haystack.dataclasses import Document
 from haystack.document_stores import DuplicatePolicy
+from haystack.utils.filters import convert
+
+from pinecone_haystack.filters import _normalize_filters
 
 logger = logging.getLogger(__name__)
 
@@ -178,7 +181,7 @@ def _embedding_retrieval(
         query_embedding: List[float],
         *,
         namespace: Optional[str] = None,
-        filters: Optional[Dict[str, Any]] = None,  # noqa: ARG002 (filters to be implemented)
+        filters: Optional[Dict[str, Any]] = None,
         top_k: int = 10,
     ) -> List[Document]:
         """
@@ -200,10 +203,15 @@ def _embedding_retrieval(
             msg = "query_embedding must be a non-empty list of floats"
             raise ValueError(msg)
 
+        if filters and "operator" not in filters and "conditions" not in filters:
+            filters = convert(filters)
+        filters = _normalize_filters(filters) if filters else None
+
         result = self._index.query(
             vector=query_embedding,
             top_k=top_k,
             namespace=namespace or self.namespace,
+            filter=filters,
             include_values=True,
             include_metadata=True,
         )
diff --git a/integrations/pinecone/src/pinecone_haystack/filters.py b/integrations/pinecone/src/pinecone_haystack/filters.py
new file mode 100644
index 000000000..805162609
--- /dev/null
+++ b/integrations/pinecone/src/pinecone_haystack/filters.py
@@ -0,0 +1,201 @@
+# SPDX-FileCopyrightText: 2023-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, Dict
+
+from haystack.errors import FilterError
+from pandas import DataFrame
+
+
+def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Converts Haystack filters in Pinecone compatible filters.
+    Reference: https://docs.pinecone.io/docs/metadata-filtering
+    """
+    if not isinstance(filters, dict):
+        msg = "Filters must be a dictionary"
+        raise FilterError(msg)
+
+    if "field" in filters:
+        return _parse_comparison_condition(filters)
+    return _parse_logical_condition(filters)
+
+
+def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Converts a Haystack logical condition into a Pinecone "$and"/"$or" expression.
+    Raises FilterError if a required key is missing or the operator is unknown.
+    """
+    if "operator" not in condition:
+        msg = f"'operator' key missing in {condition}"
+        raise FilterError(msg)
+    if "conditions" not in condition:
+        msg = f"'conditions' key missing in {condition}"
+        raise FilterError(msg)
+
+    operator = condition["operator"]
+    conditions = [_parse_comparison_condition(c) for c in condition["conditions"]]
+
+    if operator in LOGICAL_OPERATORS:
+        return {LOGICAL_OPERATORS[operator]: conditions}
+
+    msg = f"Unknown logical operator '{operator}'"
+    raise FilterError(msg)
+
+
+def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Converts a Haystack comparison condition into a Pinecone field expression.
+    Raises FilterError on a missing key, an unknown operator or an unsupported value.
+    """
+    if "field" not in condition:
+        # 'field' key is only found in comparison dictionaries.
+        # We assume this is a logic dictionary since it's not present.
+        return _parse_logical_condition(condition)
+
+    field: str = condition["field"]
+
+    if field.startswith("meta."):
+        # Remove the "meta." prefix if present.
+        # Documents are flattened when using the PineconeDocumentStore
+        # so we don't need to specify the "meta." prefix.
+        # Instead of raising an error we handle it gracefully.
+        field = field[5:]
+
+    if "operator" not in condition:
+        msg = f"'operator' key missing in {condition}"
+        raise FilterError(msg)
+    if "value" not in condition:
+        msg = f"'value' key missing in {condition}"
+        raise FilterError(msg)
+    operator: str = condition["operator"]
+    value: Any = condition["value"]
+    if isinstance(value, DataFrame):
+        value = value.to_json()
+
+    if operator not in COMPARISON_OPERATORS:
+        # Report unknown operators as FilterError instead of leaking a KeyError.
+        msg = f"Unknown comparison operator '{operator}'"
+        raise FilterError(msg)
+
+    return COMPARISON_OPERATORS[operator](field, value)
+
+
+def _equal(field: str, value: Any) -> Dict[str, Any]:
+    supported_types = (str, int, float, bool)
+    if not isinstance(value, supported_types):
+        msg = (
+            f"Unsupported type for 'equal' comparison: {type(value)}. "
+            f"Types supported by Pinecone are: {supported_types}"
+        )
+        raise FilterError(msg)
+
+    return {field: {"$eq": value}}
+
+
+def _not_equal(field: str, value: Any) -> Dict[str, Any]:
+    supported_types = (str, int, float, bool)
+    if not isinstance(value, supported_types):
+        msg = (
+            f"Unsupported type for 'inequal' comparison: {type(value)}. "
+            f"Types supported by Pinecone are: {supported_types}"
+        )
+        raise FilterError(msg)
+
+    return {field: {"$ne": value}}
+
+
+def _greater_than(field: str, value: Any) -> Dict[str, Any]:
+    supported_types = (int, float)
+    if not isinstance(value, supported_types):
+        msg = (
+            f"Unsupported type for 'greater than' comparison: {type(value)}. "
+            f"Types supported by Pinecone are: {supported_types}"
+        )
+        raise FilterError(msg)
+
+    return {field: {"$gt": value}}
+
+
+def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:
+    supported_types = (int, float)
+    if not isinstance(value, supported_types):
+        msg = (
+            f"Unsupported type for 'greater than equal' comparison: {type(value)}. "
+            f"Types supported by Pinecone are: {supported_types}"
+        )
+        raise FilterError(msg)
+
+    return {field: {"$gte": value}}
+
+
+def _less_than(field: str, value: Any) -> Dict[str, Any]:
+    supported_types = (int, float)
+    if not isinstance(value, supported_types):
+        msg = (
+            f"Unsupported type for 'less than' comparison: {type(value)}. "
+            f"Types supported by Pinecone are: {supported_types}"
+        )
+        raise FilterError(msg)
+
+    return {field: {"$lt": value}}
+
+
+def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
+    supported_types = (int, float)
+    if not isinstance(value, supported_types):
+        msg = (
+            f"Unsupported type for 'less than equal' comparison: {type(value)}. "
+            f"Types supported by Pinecone are: {supported_types}"
+        )
+        raise FilterError(msg)
+
+    return {field: {"$lte": value}}
+
+
+def _not_in(field: str, value: Any) -> Dict[str, Any]:
+    if not isinstance(value, list):
+        msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
+        raise FilterError(msg)
+
+    supported_types = (int, float, str)
+    for v in value:
+        if not isinstance(v, supported_types):
+            msg = (
+                f"Unsupported type for 'not in' comparison: {type(v)}. "
+                f"Types supported by Pinecone are: {supported_types}"
+            )
+            raise FilterError(msg)
+
+    return {field: {"$nin": value}}
+
+
+def _in(field: str, value: Any) -> Dict[str, Any]:
+    if not isinstance(value, list):
+        msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
+        raise FilterError(msg)
+
+    supported_types = (int, float, str)
+    for v in value:
+        if not isinstance(v, supported_types):
+            msg = (
+                f"Unsupported type for 'in' comparison: {type(v)}. "
+                f"Types supported by Pinecone are: {supported_types}"
+            )
+            raise FilterError(msg)
+
+    return {field: {"$in": value}}
+
+
+COMPARISON_OPERATORS = {
+    "==": _equal,
+    "!=": _not_equal,
+    ">": _greater_than,
+    ">=": _greater_than_equal,
+    "<": _less_than,
+    "<=": _less_than_equal,
+    "in": _in,
+    "not in": _not_in,
+}
+
+LOGICAL_OPERATORS = {"AND": "$and", "OR": "$or"}
diff --git a/integrations/pinecone/tests/test_filters.py b/integrations/pinecone/tests/test_filters.py
new file mode 100644
index 000000000..1e6aeb0cd
--- /dev/null
+++ b/integrations/pinecone/tests/test_filters.py
@@ -0,0 +1,81 @@
+from typing import List
+
+import pytest
+from haystack.dataclasses.document import Document
+from haystack.testing.document_store import (
+    FilterDocumentsTest,
+)
+
+
+class TestFilters(FilterDocumentsTest):
+    def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
+        for doc in received:
+            # Pinecone seems to convert strings to datetime objects (undocumented behavior)
+            # We convert them back to strings to compare them
+            if "date" in doc.meta:
+                doc.meta["date"] = doc.meta["date"].isoformat()
+            # Pinecone seems to convert integers to floats (undocumented behavior)
+            # We convert them back to integers to compare them
+            if "number" in doc.meta:
+                doc.meta["number"] = int(doc.meta["number"])
+
+        # Lists comparison
+        assert len(received) == len(expected)
+        received.sort(key=lambda x: x.id)
+        expected.sort(key=lambda x: x.id)
+        for received_doc, expected_doc in zip(received, expected):
+            assert received_doc.meta == expected_doc.meta
+            assert received_doc.content == expected_doc.content
+            if received_doc.dataframe is None:
+                assert expected_doc.dataframe is None
+            else:
+                assert received_doc.dataframe.equals(expected_doc.dataframe)
+            # unfortunately, Pinecone returns a slightly different embedding
+            if received_doc.embedding is None:
+                assert expected_doc.embedding is None
+            else:
+                assert received_doc.embedding == pytest.approx(expected_doc.embedding)
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with null values")
+    def test_comparison_equal_with_none(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with null values")
+    def test_comparison_not_equal_with_none(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with dates")
+    def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with null values")
+    def test_comparison_greater_than_with_none(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with dates")
+    def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with null values")
+    def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with dates")
+    def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with null values")
+    def test_comparison_less_than_with_none(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with dates")
+    def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with null values")
+    def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support the 'not' operator")
+    def test_not_operator(self, document_store, filterable_docs):
+        ...