From 435da6a5c59131c9e9090c2f263a87e993427224 Mon Sep 17 00:00:00 2001
From: anakin87
Date: Fri, 22 Dec 2023 16:05:49 +0100
Subject: [PATCH 1/5] filters

---
 .../src/pinecone_haystack/document_store.py   |  10 +-
 .../pinecone/src/pinecone_haystack/filters.py | 193 ++++++++++++++++++
 integrations/pinecone/tests/test_filters.py   |  81 ++++++++
 3 files changed, 283 insertions(+), 1 deletion(-)
 create mode 100644 integrations/pinecone/src/pinecone_haystack/filters.py
 create mode 100644 integrations/pinecone/tests/test_filters.py

diff --git a/integrations/pinecone/src/pinecone_haystack/document_store.py b/integrations/pinecone/src/pinecone_haystack/document_store.py
index 576993de6..d6296e030 100644
--- a/integrations/pinecone/src/pinecone_haystack/document_store.py
+++ b/integrations/pinecone/src/pinecone_haystack/document_store.py
@@ -12,6 +12,9 @@
 from haystack import default_to_dict
 from haystack.dataclasses import Document
 from haystack.document_stores import DuplicatePolicy
+from haystack.utils.filters import convert
+
+from pinecone_haystack.filters import _normalize_filters
 
 logger = logging.getLogger(__name__)
 
@@ -178,7 +181,7 @@ def _embedding_retrieval(
         query_embedding: List[float],
         *,
         namespace: Optional[str] = None,
-        filters: Optional[Dict[str, Any]] = None,  # noqa: ARG002 (filters to be implemented)
+        filters: Optional[Dict[str, Any]] = None,
         top_k: int = 10,
     ) -> List[Document]:
         """
@@ -200,10 +203,15 @@
             msg = "query_embedding must be a non-empty list of floats"
             raise ValueError(msg)
 
+        if filters and "operator" not in filters and "conditions" not in filters:
+            filters = convert(filters)
+        filters = _normalize_filters(filters) if filters else None
+
         result = self._index.query(
             vector=query_embedding,
             top_k=top_k,
             namespace=namespace or self.namespace,
+            filter=filters,
             include_values=True,
             include_metadata=True,
         )
diff --git a/integrations/pinecone/src/pinecone_haystack/filters.py b/integrations/pinecone/src/pinecone_haystack/filters.py
new file mode 100644
index 000000000..805162609
--- /dev/null
+++ b/integrations/pinecone/src/pinecone_haystack/filters.py
@@ -0,0 +1,193 @@
+# SPDX-FileCopyrightText: 2023-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, Dict
+
+from haystack.errors import FilterError
+from pandas import DataFrame
+
+
+def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Converts Haystack filters in Pinecone compatible filters.
+    Reference: https://docs.pinecone.io/docs/metadata-filtering
+    """
+    if not isinstance(filters, dict):
+        msg = "Filters must be a dictionary"
+        raise FilterError(msg)
+
+    if "field" in filters:
+        return _parse_comparison_condition(filters)
+    return _parse_logical_condition(filters)
+
+
+def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
+    if "operator" not in condition:
+        msg = f"'operator' key missing in {condition}"
+        raise FilterError(msg)
+    if "conditions" not in condition:
+        msg = f"'conditions' key missing in {condition}"
+        raise FilterError(msg)
+
+    operator = condition["operator"]
+    conditions = [_parse_comparison_condition(c) for c in condition["conditions"]]
+
+    if operator in LOGICAL_OPERATORS:
+        return {LOGICAL_OPERATORS[operator]: conditions}
+
+    msg = f"Unknown logical operator '{operator}'"
+    raise FilterError(msg)
+
+
+def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
+    if "field" not in condition:
+        # 'field' key is only found in comparison dictionaries.
+        # We assume this is a logic dictionary since it's not present.
+        return _parse_logical_condition(condition)
+
+    field: str = condition["field"]
+
+    if field.startswith("meta."):
+        # Remove the "meta." prefix if present.
+        # Documents are flattened when using the PineconeDocumentStore
+        # so we don't need to specify the "meta." prefix.
+        # Instead of raising an error we handle it gracefully.
+        field = field[5:]
+
+    # if field == "content":
+    #     field = "meta.content"
+    # if field == "dataframe":
+    #     field = "meta.dataframe"
+
+    if "operator" not in condition:
+        msg = f"'operator' key missing in {condition}"
+        raise FilterError(msg)
+    if "value" not in condition:
+        msg = f"'value' key missing in {condition}"
+        raise FilterError(msg)
+    operator: str = condition["operator"]
+    value: Any = condition["value"]
+    if isinstance(value, DataFrame):
+        value = value.to_json()
+
+    return COMPARISON_OPERATORS[operator](field, value)
+
+
+def _equal(field: str, value: Any) -> Dict[str, Any]:
+    supported_types = (str, int, float, bool)
+    if not isinstance(value, supported_types):
+        msg = (
+            f"Unsupported type for 'equal' comparison: {type(value)}. "
+            f"Types supported by Pinecone are: {supported_types}"
+        )
+        raise FilterError(msg)
+
+    return {field: {"$eq": value}}
+
+
+def _not_equal(field: str, value: Any) -> Dict[str, Any]:
+    supported_types = (str, int, float, bool)
+    if not isinstance(value, supported_types):
+        msg = (
+            f"Unsupported type for 'inequal' comparison: {type(value)}. "
+            f"Types supported by Pinecone are: {supported_types}"
+        )
+        raise FilterError(msg)
+
+    return {field: {"$ne": value}}
+
+
+def _greater_than(field: str, value: Any) -> Dict[str, Any]:
+    supported_types = (int, float)
+    if not isinstance(value, supported_types):
+        msg = (
+            f"Unsupported type for 'greater than' comparison: {type(value)}. "
+            f"Types supported by Pinecone are: {supported_types}"
+        )
+        raise FilterError(msg)
+
+    return {field: {"$gt": value}}
+
+
+def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:
+    supported_types = (int, float)
+    if not isinstance(value, supported_types):
+        msg = (
+            f"Unsupported type for 'greater than equal' comparison: {type(value)}. "
+            f"Types supported by Pinecone are: {supported_types}"
+        )
+        raise FilterError(msg)
+
+    return {field: {"$gte": value}}
+
+
+def _less_than(field: str, value: Any) -> Dict[str, Any]:
+    supported_types = (int, float)
+    if not isinstance(value, supported_types):
+        msg = (
+            f"Unsupported type for 'less than' comparison: {type(value)}. "
+            f"Types supported by Pinecone are: {supported_types}"
+        )
+        raise FilterError(msg)
+
+    return {field: {"$lt": value}}
+
+
+def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
+    supported_types = (int, float)
+    if not isinstance(value, supported_types):
+        msg = (
+            f"Unsupported type for 'less than equal' comparison: {type(value)}. "
+            f"Types supported by Pinecone are: {supported_types}"
+        )
+        raise FilterError(msg)
+
+    return {field: {"$lte": value}}
+
+
+def _not_in(field: str, value: Any) -> Dict[str, Any]:
+    if not isinstance(value, list):
+        msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
+        raise FilterError(msg)
+
+    supported_types = (int, float, str)
+    for v in value:
+        if not isinstance(v, supported_types):
+            msg = (
+                f"Unsupported type for 'not in' comparison: {type(v)}. "
+                f"Types supported by Pinecone are: {supported_types}"
+            )
+            raise FilterError(msg)
+
+    return {field: {"$nin": value}}
+
+
+def _in(field: str, value: Any) -> Dict[str, Any]:
+    if not isinstance(value, list):
+        msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
+        raise FilterError(msg)
+
+    supported_types = (int, float, str)
+    for v in value:
+        if not isinstance(v, supported_types):
+            msg = (
+                f"Unsupported type for 'in' comparison: {type(v)}. "
+                f"Types supported by Pinecone are: {supported_types}"
+            )
+            raise FilterError(msg)
+
+    return {field: {"$in": value}}
+
+
+COMPARISON_OPERATORS = {
+    "==": _equal,
+    "!=": _not_equal,
+    ">": _greater_than,
+    ">=": _greater_than_equal,
+    "<": _less_than,
+    "<=": _less_than_equal,
+    "in": _in,
+    "not in": _not_in,
+}
+
+LOGICAL_OPERATORS = {"AND": "$and", "OR": "$or"}
diff --git a/integrations/pinecone/tests/test_filters.py b/integrations/pinecone/tests/test_filters.py
new file mode 100644
index 000000000..1e6aeb0cd
--- /dev/null
+++ b/integrations/pinecone/tests/test_filters.py
@@ -0,0 +1,81 @@
+from typing import List
+
+import pytest
+from haystack.dataclasses.document import Document
+from haystack.testing.document_store import (
+    FilterDocumentsTest,
+)
+
+
+class TestFilters(FilterDocumentsTest):
+    def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
+        for doc in received:
+            # Pinecone seems to convert strings to datetime objects (undocumented behavior)
+            # We convert them back to strings to compare them
+            if "date" in doc.meta:
+                doc.meta["date"] = doc.meta["date"].isoformat()
+            # Pinecone seems to convert integers to floats (undocumented behavior)
+            # We convert them back to integers to compare them
+            if "number" in doc.meta:
+                doc.meta["number"] = int(doc.meta["number"])
+
+        # Lists comparison
+        assert len(received) == len(expected)
+        received.sort(key=lambda x: x.id)
+        expected.sort(key=lambda x: x.id)
+        for received_doc, expected_doc in zip(received, expected):
+            assert received_doc.meta == expected_doc.meta
+            assert received_doc.content == expected_doc.content
+            if received_doc.dataframe is None:
+                assert expected_doc.dataframe is None
+            else:
+                assert received_doc.dataframe.equals(expected_doc.dataframe)
+            # unfortunately, Pinecone returns a slightly different embedding
+            if received_doc.embedding is None:
+                assert expected_doc.embedding is None
+            else:
+                assert received_doc.embedding == pytest.approx(expected_doc.embedding)
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with null values")
+    def test_comparison_equal_with_none(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with null values")
+    def test_comparison_not_equal_with_none(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with dates")
+    def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with null values")
+    def test_comparison_greater_than_with_none(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with dates")
+    def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with null values")
+    def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with dates")
+    def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with null values")
+    def test_comparison_less_than_with_none(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with dates")
+    def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support comparison with null values")
+    def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs):
+        ...
+
+    @pytest.mark.skip(reason="Pinecone does not support the 'not' operator")
+    def test_not_operator(self, document_store, filterable_docs):
+        ...

From 4b3fcfe992d3b84e88b11e98062f9fb1672124c1 Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci
Date: Fri, 22 Dec 2023 16:22:27 +0100
Subject: [PATCH 2/5] Update integrations/pinecone/src/pinecone_haystack/filters.py

Co-authored-by: Massimiliano Pippi
---
 integrations/pinecone/src/pinecone_haystack/filters.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/integrations/pinecone/src/pinecone_haystack/filters.py b/integrations/pinecone/src/pinecone_haystack/filters.py
index 805162609..a84088579 100644
--- a/integrations/pinecone/src/pinecone_haystack/filters.py
+++ b/integrations/pinecone/src/pinecone_haystack/filters.py
@@ -89,7 +89,7 @@ def _not_equal(field: str, value: Any) -> Dict[str, Any]:
     supported_types = (str, int, float, bool)
     if not isinstance(value, supported_types):
         msg = (
-            f"Unsupported type for 'inequal' comparison: {type(value)}. "
+            f"Unsupported type for 'not equal' comparison: {type(value)}. "
            f"Types supported by Pinecone are: {supported_types}"
         )
         raise FilterError(msg)

From 0c8d62741b93e9fda748212f2c4de1a86b6e0f26 Mon Sep 17 00:00:00 2001
From: anakin87
Date: Fri, 22 Dec 2023 16:26:20 +0100
Subject: [PATCH 3/5] improv from PR review

---
 .../pinecone/src/pinecone_haystack/filters.py | 22 +++++++------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/integrations/pinecone/src/pinecone_haystack/filters.py b/integrations/pinecone/src/pinecone_haystack/filters.py
index 805162609..75c43585e 100644
--- a/integrations/pinecone/src/pinecone_haystack/filters.py
+++ b/integrations/pinecone/src/pinecone_haystack/filters.py
@@ -46,19 +46,6 @@ def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
         return _parse_logical_condition(condition)
 
     field: str = condition["field"]
-
-    if field.startswith("meta."):
-        # Remove the "meta." prefix if present.
-        # Documents are flattened when using the PineconeDocumentStore
-        # so we don't need to specify the "meta." prefix.
-        # Instead of raising an error we handle it gracefully.
-        field = field[5:]
-
-    # if field == "content":
-    #     field = "meta.content"
-    # if field == "dataframe":
-    #     field = "meta.dataframe"
-
     if "operator" not in condition:
         msg = f"'operator' key missing in {condition}"
         raise FilterError(msg)
@@ -66,9 +53,16 @@ def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
         msg = f"'value' key missing in {condition}"
         raise FilterError(msg)
     operator: str = condition["operator"]
+    if field.startswith("meta."):
+        # Remove the "meta." prefix if present.
+        # Documents are flattened when using the PineconeDocumentStore
+        # so we don't need to specify the "meta." prefix.
+        # Instead of raising an error we handle it gracefully.
+        field = field[5:]
+
     value: Any = condition["value"]
     if isinstance(value, DataFrame):
-        value = value.to_json()
+            value = value.to_json()
 
     return COMPARISON_OPERATORS[operator](field, value)
 

From a8c305eee74b1c1272839b65c3b779aed91e9176 Mon Sep 17 00:00:00 2001
From: anakin87
Date: Fri, 22 Dec 2023 16:27:58 +0100
Subject: [PATCH 4/5] fmt

---
 integrations/pinecone/src/pinecone_haystack/filters.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/integrations/pinecone/src/pinecone_haystack/filters.py b/integrations/pinecone/src/pinecone_haystack/filters.py
index c880867de..2ddb26d61 100644
--- a/integrations/pinecone/src/pinecone_haystack/filters.py
+++ b/integrations/pinecone/src/pinecone_haystack/filters.py
@@ -62,7 +62,7 @@ def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
 
     value: Any = condition["value"]
     if isinstance(value, DataFrame):
-            value = value.to_json()
+        value = value.to_json()
 
     return COMPARISON_OPERATORS[operator](field, value)
 

From 8e94070c366ef801cf3b01431bbc94b280ac1bfe Mon Sep 17 00:00:00 2001
From: anakin87
Date: Fri, 22 Dec 2023 16:53:03 +0100
Subject: [PATCH 5/5] dense retriever!

---
 .../src/pinecone_haystack/dense_retriever.py |  72 +++++++++++
 .../pinecone/tests/test_dense_retriever.py   | 100 ++++++++++++++++++
 2 files changed, 172 insertions(+)
 create mode 100644 integrations/pinecone/src/pinecone_haystack/dense_retriever.py
 create mode 100644 integrations/pinecone/tests/test_dense_retriever.py

diff --git a/integrations/pinecone/src/pinecone_haystack/dense_retriever.py b/integrations/pinecone/src/pinecone_haystack/dense_retriever.py
new file mode 100644
index 000000000..3f60f252b
--- /dev/null
+++ b/integrations/pinecone/src/pinecone_haystack/dense_retriever.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: 2023-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, Dict, List, Optional
+
+from haystack import component, default_from_dict, default_to_dict
+from haystack.dataclasses import Document
+
+from pinecone_haystack.document_store import PineconeDocumentStore
+
+
+@component
+class PineconeDenseRetriever:
+    """
+    Retrieves documents from the PineconeDocumentStore, based on their dense embeddings.
+
+    Needs to be connected to the PineconeDocumentStore.
+    """
+
+    def __init__(
+        self,
+        *,
+        document_store: PineconeDocumentStore,
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: int = 10,
+    ):
+        """
+        Create the PineconeDenseRetriever component.
+
+        :param document_store: An instance of PineconeDocumentStore.
+        :param filters: Filters applied to the retrieved Documents. Defaults to None.
+        :param top_k: Maximum number of Documents to return, defaults to 10.
+
+        :raises ValueError: If `document_store` is not an instance of PineconeDocumentStore.
+        """
+        if not isinstance(document_store, PineconeDocumentStore):
+            msg = "document_store must be an instance of PineconeDocumentStore"
+            raise ValueError(msg)
+
+        self.document_store = document_store
+        self.filters = filters or {}
+        self.top_k = top_k
+
+    def to_dict(self) -> Dict[str, Any]:
+        return default_to_dict(
+            self,
+            filters=self.filters,
+            top_k=self.top_k,
+            document_store=self.document_store.to_dict(),
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "PineconeDenseRetriever":
+        data["init_parameters"]["document_store"] = default_from_dict(
+            PineconeDocumentStore, data["init_parameters"]["document_store"]
+        )
+        return default_from_dict(cls, data)
+
+    @component.output_types(documents=List[Document])
+    def run(self, query_embedding: List[float]):
+        """
+        Retrieve documents from the PineconeDocumentStore, based on their dense embeddings.
+
+        :param query_embedding: Embedding of the query.
+        :return: List of Document similar to `query_embedding`.
+        """
+        docs = self.document_store._embedding_retrieval(
+            query_embedding=query_embedding,
+            filters=self.filters,
+            top_k=self.top_k,
+        )
+        return {"documents": docs}
diff --git a/integrations/pinecone/tests/test_dense_retriever.py b/integrations/pinecone/tests/test_dense_retriever.py
new file mode 100644
index 000000000..ceb73b687
--- /dev/null
+++ b/integrations/pinecone/tests/test_dense_retriever.py
@@ -0,0 +1,100 @@
+# SPDX-FileCopyrightText: 2023-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+from unittest.mock import Mock, patch
+
+from haystack.dataclasses import Document
+
+from pinecone_haystack.dense_retriever import PineconeDenseRetriever
+from pinecone_haystack.document_store import PineconeDocumentStore
+
+
+def test_init_default():
+    mock_store = Mock(spec=PineconeDocumentStore)
+    retriever = PineconeDenseRetriever(document_store=mock_store)
+    assert retriever.document_store == mock_store
+    assert retriever.filters == {}
+    assert retriever.top_k == 10
+
+
+@patch("pinecone_haystack.document_store.pinecone")
+def test_to_dict(mock_pinecone):
+    mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512}
+    document_store = PineconeDocumentStore(
+        api_key="test-key",
+        environment="gcp-starter",
+        index="default",
+        namespace="test-namespace",
+        batch_size=50,
+        dimension=512,
+    )
+    retriever = PineconeDenseRetriever(document_store=document_store)
+    res = retriever.to_dict()
+    assert res == {
+        "type": "pinecone_haystack.dense_retriever.PineconeDenseRetriever",
+        "init_parameters": {
+            "document_store": {
+                "init_parameters": {
+                    "environment": "gcp-starter",
+                    "index": "default",
+                    "namespace": "test-namespace",
+                    "batch_size": 50,
+                    "dimension": 512,
+                },
+                "type": "pinecone_haystack.document_store.PineconeDocumentStore",
+            },
+            "filters": {},
+            "top_k": 10,
+        },
+    }
+
+
+@patch("pinecone_haystack.document_store.pinecone")
+def test_from_dict(mock_pinecone, monkeypatch):
+    data = {
+        "type": "pinecone_haystack.dense_retriever.PineconeDenseRetriever",
+        "init_parameters": {
+            "document_store": {
+                "init_parameters": {
+                    "environment": "gcp-starter",
+                    "index": "default",
+                    "namespace": "test-namespace",
+                    "batch_size": 50,
+                    "dimension": 512,
+                },
+                "type": "pinecone_haystack.document_store.PineconeDocumentStore",
+            },
+            "filters": {},
+            "top_k": 10,
+        },
+    }
+
+    mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512}
+    monkeypatch.setenv("PINECONE_API_KEY", "test-key")
+    retriever = PineconeDenseRetriever.from_dict(data)
+
+    document_store = retriever.document_store
+    assert document_store.environment == "gcp-starter"
+    assert document_store.index == "default"
+    assert document_store.namespace == "test-namespace"
+    assert document_store.batch_size == 50
+    assert document_store.dimension == 512
+
+    assert retriever.filters == {}
+    assert retriever.top_k == 10
+
+
+def test_run():
+    mock_store = Mock(spec=PineconeDocumentStore)
+    mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
+    retriever = PineconeDenseRetriever(document_store=mock_store)
+    res = retriever.run(query_embedding=[0.5, 0.7])
+    mock_store._embedding_retrieval.assert_called_once_with(
+        query_embedding=[0.5, 0.7],
+        filters={},
+        top_k=10,
+    )
+    assert len(res) == 1
+    assert len(res["documents"]) == 1
+    assert res["documents"][0].content == "Test doc"
+    assert res["documents"][0].embedding == [0.1, 0.2]
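
Illustration of the filter conversion introduced in filters.py (a sketch for context, not part of the patches; the field names and values are invented). Logical operators map to "$and"/"$or", comparison operators map to "$eq", "$ne", "$gt", "$gte", "$lt", "$lte", "$in", "$nin", and a leading "meta." prefix is stripped because documents are stored flattened in Pinecone:

    from pinecone_haystack.filters import _normalize_filters

    haystack_filters = {
        "operator": "AND",
        "conditions": [
            {"field": "meta.type", "operator": "==", "value": "article"},
            {"field": "meta.rating", "operator": ">=", "value": 3},
        ],
    }

    # Expected Pinecone filter, given the mapping above:
    # {"$and": [{"type": {"$eq": "article"}}, {"rating": {"$gte": 3}}]}
    print(_normalize_filters(haystack_filters))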
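
A minimal usage sketch for the PineconeDenseRetriever added in patch 5. It assumes PINECONE_API_KEY is set in the environment (as the from_dict test relies on) and that the index already exists; the index name, namespace, dimension, filter and query embedding below are placeholders:

    from pinecone_haystack.dense_retriever import PineconeDenseRetriever
    from pinecone_haystack.document_store import PineconeDocumentStore

    document_store = PineconeDocumentStore(
        environment="gcp-starter",
        index="default",
        namespace="test-namespace",
        dimension=768,
    )
    retriever = PineconeDenseRetriever(
        document_store=document_store,
        filters={"field": "meta.type", "operator": "==", "value": "article"},
        top_k=5,
    )
    # run() returns a dictionary with a "documents" key, as defined in the component's run() method.
    result = retriever.run(query_embedding=[0.1] * 768)
    print(result["documents"])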
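
The document_store.py change in patch 1 wires the same filters into PineconeDocumentStore._embedding_retrieval: legacy-style dicts without "operator"/"conditions" keys are first converted with haystack.utils.filters.convert, then normalized and passed to the Pinecone index query as filter. A sketch of a direct call, reusing the document_store created in the previous snippet (the method is private and is normally called through the retriever; all values are placeholders):

    docs = document_store._embedding_retrieval(
        query_embedding=[0.1] * 768,
        filters={"field": "meta.type", "operator": "==", "value": "article"},
        top_k=5,
    )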