Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pinecone - filters #133

Merged
merged 5 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from haystack import default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores import DuplicatePolicy
from haystack.utils.filters import convert

from pinecone_haystack.filters import _normalize_filters

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -178,7 +181,7 @@ def _embedding_retrieval(
query_embedding: List[float],
*,
namespace: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None, # noqa: ARG002 (filters to be implemented)
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
) -> List[Document]:
"""
Expand All @@ -200,10 +203,15 @@ def _embedding_retrieval(
msg = "query_embedding must be a non-empty list of floats"
raise ValueError(msg)

if filters and "operator" not in filters and "conditions" not in filters:
filters = convert(filters)
filters = _normalize_filters(filters) if filters else None

result = self._index.query(
vector=query_embedding,
top_k=top_k,
namespace=namespace or self.namespace,
filter=filters,
include_values=True,
include_metadata=True,
)
Expand Down
187 changes: 187 additions & 0 deletions integrations/pinecone/src/pinecone_haystack/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict

from haystack.errors import FilterError
from pandas import DataFrame


def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
"""
Converts Haystack filters in Pinecone compatible filters.
Reference: https://docs.pinecone.io/docs/metadata-filtering
"""
if not isinstance(filters, dict):
msg = "Filters must be a dictionary"
raise FilterError(msg)

if "field" in filters:
return _parse_comparison_condition(filters)
return _parse_logical_condition(filters)


def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
if "conditions" not in condition:
msg = f"'conditions' key missing in {condition}"
raise FilterError(msg)

operator = condition["operator"]
conditions = [_parse_comparison_condition(c) for c in condition["conditions"]]

if operator in LOGICAL_OPERATORS:
return {LOGICAL_OPERATORS[operator]: conditions}

msg = f"Unknown logical operator '{operator}'"
raise FilterError(msg)


def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
if "field" not in condition:
# 'field' key is only found in comparison dictionaries.
# We assume this is a logic dictionary since it's not present.
return _parse_logical_condition(condition)

field: str = condition["field"]
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
if "value" not in condition:
msg = f"'value' key missing in {condition}"
raise FilterError(msg)
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
operator: str = condition["operator"]
if field.startswith("meta."):
# Remove the "meta." prefix if present.
# Documents are flattened when using the PineconeDocumentStore
# so we don't need to specify the "meta." prefix.
# Instead of raising an error we handle it gracefully.
field = field[5:]

value: Any = condition["value"]
if isinstance(value, DataFrame):
value = value.to_json()

return COMPARISON_OPERATORS[operator](field, value)


def _equal(field: str, value: Any) -> Dict[str, Any]:
supported_types = (str, int, float, bool)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'equal' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$eq": value}}


def _not_equal(field: str, value: Any) -> Dict[str, Any]:
supported_types = (str, int, float, bool)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'not equal' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$ne": value}}


def _greater_than(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'greater than' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$gt": value}}


def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'greater than equal' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$gte": value}}


def _less_than(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'less than' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$lt": value}}


def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'less than equal' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$lte": value}}


def _not_in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
raise FilterError(msg)

supported_types = (int, float, str)
for v in value:
if not isinstance(v, supported_types):
msg = (
f"Unsupported type for 'not in' comparison: {type(v)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$nin": value}}


def _in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
raise FilterError(msg)

supported_types = (int, float, str)
for v in value:
if not isinstance(v, supported_types):
msg = (
f"Unsupported type for 'in' comparison: {type(v)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$in": value}}


COMPARISON_OPERATORS = {
"==": _equal,
"!=": _not_equal,
">": _greater_than,
">=": _greater_than_equal,
"<": _less_than,
"<=": _less_than_equal,
"in": _in,
"not in": _not_in,
}

LOGICAL_OPERATORS = {"AND": "$and", "OR": "$or"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would put module constants at the top

Copy link
Member Author

@anakin87 anakin87 Dec 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fact is that COMPARISON_OPERATORS depends on some functions, so they should live after the function definitions.
And I would prefer LOGICAL_OPERATORS to be close.

81 changes: 81 additions & 0 deletions integrations/pinecone/tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import List

import pytest
from haystack.dataclasses.document import Document
from haystack.testing.document_store import (
FilterDocumentsTest,
)


class TestFilters(FilterDocumentsTest):
def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
for doc in received:
# Pinecone seems to convert strings to datetime objects (undocumented behavior)
# We convert them back to strings to compare them
if "date" in doc.meta:
doc.meta["date"] = doc.meta["date"].isoformat()
# Pinecone seems to convert integers to floats (undocumented behavior)
# We convert them back to integers to compare them
if "number" in doc.meta:
doc.meta["number"] = int(doc.meta["number"])

# Lists comparison
assert len(received) == len(expected)
received.sort(key=lambda x: x.id)
expected.sort(key=lambda x: x.id)
for received_doc, expected_doc in zip(received, expected):
assert received_doc.meta == expected_doc.meta
assert received_doc.content == expected_doc.content
if received_doc.dataframe is None:
assert expected_doc.dataframe is None
else:
assert received_doc.dataframe.equals(expected_doc.dataframe)
# unfortunately, Pinecone returns a slightly different embedding
if received_doc.embedding is None:
assert expected_doc.embedding is None
else:
assert received_doc.embedding == pytest.approx(expected_doc.embedding)

@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_equal_with_none(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_not_equal_with_none(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with dates")
def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_greater_than_with_none(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with dates")
def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with dates")
def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_less_than_with_none(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with dates")
def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs):
...

@pytest.mark.skip(reason="Pinecone does not support the 'not' operator")
def test_not_operator(self, document_store, filterable_docs):
...