Skip to content

Commit

Permalink
filters
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Dec 22, 2023
1 parent fbdb9a0 commit 435da6a
Show file tree
Hide file tree
Showing 3 changed files with 283 additions and 1 deletion.
10 changes: 9 additions & 1 deletion integrations/pinecone/src/pinecone_haystack/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from haystack import default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores import DuplicatePolicy
from haystack.utils.filters import convert

from pinecone_haystack.filters import _normalize_filters

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -178,7 +181,7 @@ def _embedding_retrieval(
query_embedding: List[float],
*,
namespace: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None, # noqa: ARG002 (filters to be implemented)
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
) -> List[Document]:
"""
Expand All @@ -200,10 +203,15 @@ def _embedding_retrieval(
msg = "query_embedding must be a non-empty list of floats"
raise ValueError(msg)

if filters and "operator" not in filters and "conditions" not in filters:
filters = convert(filters)
filters = _normalize_filters(filters) if filters else None

result = self._index.query(
vector=query_embedding,
top_k=top_k,
namespace=namespace or self.namespace,
filter=filters,
include_values=True,
include_metadata=True,
)
Expand Down
193 changes: 193 additions & 0 deletions integrations/pinecone/src/pinecone_haystack/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict

from haystack.errors import FilterError
from pandas import DataFrame


def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
"""
Converts Haystack filters in Pinecone compatible filters.
Reference: https://docs.pinecone.io/docs/metadata-filtering
"""
if not isinstance(filters, dict):
msg = "Filters must be a dictionary"
raise FilterError(msg)

if "field" in filters:
return _parse_comparison_condition(filters)
return _parse_logical_condition(filters)


def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
if "conditions" not in condition:
msg = f"'conditions' key missing in {condition}"
raise FilterError(msg)

operator = condition["operator"]
conditions = [_parse_comparison_condition(c) for c in condition["conditions"]]

if operator in LOGICAL_OPERATORS:
return {LOGICAL_OPERATORS[operator]: conditions}

msg = f"Unknown logical operator '{operator}'"
raise FilterError(msg)


def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
if "field" not in condition:
# 'field' key is only found in comparison dictionaries.
# We assume this is a logic dictionary since it's not present.
return _parse_logical_condition(condition)

field: str = condition["field"]

if field.startswith("meta."):
# Remove the "meta." prefix if present.
# Documents are flattened when using the PineconeDocumentStore
# so we don't need to specify the "meta." prefix.
# Instead of raising an error we handle it gracefully.
field = field[5:]

# if field == "content":
# field = "meta.content"
# if field == "dataframe":
# field = "meta.dataframe"

if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
if "value" not in condition:
msg = f"'value' key missing in {condition}"
raise FilterError(msg)
operator: str = condition["operator"]
value: Any = condition["value"]
if isinstance(value, DataFrame):
value = value.to_json()

return COMPARISON_OPERATORS[operator](field, value)


def _equal(field: str, value: Any) -> Dict[str, Any]:
supported_types = (str, int, float, bool)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'equal' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$eq": value}}


def _not_equal(field: str, value: Any) -> Dict[str, Any]:
supported_types = (str, int, float, bool)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'inequal' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$ne": value}}


def _greater_than(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'greater than' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$gt": value}}


def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'greater than equal' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$gte": value}}


def _less_than(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'less than' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$lt": value}}


def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'less than equal' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$lte": value}}


def _not_in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
raise FilterError(msg)

supported_types = (int, float, str)
for v in value:
if not isinstance(v, supported_types):
msg = (
f"Unsupported type for 'not in' comparison: {type(v)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$nin": value}}


def _in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
raise FilterError(msg)

supported_types = (int, float, str)
for v in value:
if not isinstance(v, supported_types):
msg = (
f"Unsupported type for 'in' comparison: {type(v)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$in": value}}


COMPARISON_OPERATORS = {
"==": _equal,
"!=": _not_equal,
">": _greater_than,
">=": _greater_than_equal,
"<": _less_than,
"<=": _less_than_equal,
"in": _in,
"not in": _not_in,
}

LOGICAL_OPERATORS = {"AND": "$and", "OR": "$or"}
81 changes: 81 additions & 0 deletions integrations/pinecone/tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import List

import pytest
from haystack.dataclasses.document import Document
from haystack.testing.document_store import (
FilterDocumentsTest,
)


class TestFilters(FilterDocumentsTest):
def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
for doc in received:
# Pinecone seems to convert strings to datetime objects (undocumented behavior)
# We convert them back to strings to compare them
if "date" in doc.meta:
doc.meta["date"] = doc.meta["date"].isoformat()
# Pinecone seems to convert integers to floats (undocumented behavior)
# We convert them back to integers to compare them
if "number" in doc.meta:
doc.meta["number"] = int(doc.meta["number"])

# Lists comparison
assert len(received) == len(expected)
received.sort(key=lambda x: x.id)
expected.sort(key=lambda x: x.id)
for received_doc, expected_doc in zip(received, expected):
assert received_doc.meta == expected_doc.meta
assert received_doc.content == expected_doc.content
if received_doc.dataframe is None:
assert expected_doc.dataframe is None
else:
assert received_doc.dataframe.equals(expected_doc.dataframe)
# unfortunately, Pinecone returns a slightly different embedding
if received_doc.embedding is None:
assert expected_doc.embedding is None
else:
assert received_doc.embedding == pytest.approx(expected_doc.embedding)

@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_equal_with_none(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_not_equal_with_none(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with dates")
def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_greater_than_with_none(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with dates")
def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with dates")
def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_less_than_with_none(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with dates")
def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support the 'not' operator")
def test_not_operator(self, document_store, filterable_docs):
...

0 comments on commit 435da6a

Please sign in to comment.