Skip to content

Commit

Permalink
in/not_in filters
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Dec 21, 2023
1 parent 69bee3d commit c6600e9
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
39 changes: 37 additions & 2 deletions integrations/pinecone/src/pinecone_haystack/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:

return {field: {"$gte": value}}


def _less_than(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
Expand All @@ -129,13 +130,47 @@ def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
return {field: {"$lte": value}}


def _not_in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
raise FilterError(msg)

supported_types = (int, float, str)
for v in value:
if not isinstance(v, supported_types):
msg = (
f"Unsupported type for 'not in' comparison: {type(v)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$nin": value}}


def _in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
raise FilterError(msg)

supported_types = (int, float, str)
for v in value:
if not isinstance(v, supported_types):
msg = (
f"Unsupported type for 'in' comparison: {type(v)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$in": value}}


COMPARISON_OPERATORS = {
"==": _equal,
"!=": _not_equal,
">": _greater_than,
">=": _greater_than_equal,
"<": _less_than,
"<=": _less_than_equal,
# "in": _in,
# "not in": _not_in,
"in": _in,
"not in": _not_in,
}
2 changes: 0 additions & 2 deletions integrations/pinecone/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do
if "number" in doc.meta:
doc.meta["number"] = int(doc.meta["number"])


# let's compare the documents
assert len(received) == len(expected)
for received_doc in received:
Expand Down Expand Up @@ -83,4 +82,3 @@ def test_comparison_less_than_equal_with_iso_date(self, document_store, filterab
@pytest.mark.skip(reason="Pinecone does not support comparison with null values")
def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs):
...

0 comments on commit c6600e9

Please sign in to comment.