Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pinecone - dense retriever #145

Merged
merged 6 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions integrations/pinecone/src/pinecone_haystack/dense_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document

from pinecone_haystack.document_store import PineconeDocumentStore


@component
class PineconeDenseRetriever:
    """
    Retrieves documents from the PineconeDocumentStore, based on their dense embeddings.

    Needs to be connected to the PineconeDocumentStore.
    """

    def __init__(
        self,
        *,
        document_store: PineconeDocumentStore,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
    ):
        """
        Create the PineconeDenseRetriever component.

        :param document_store: An instance of PineconeDocumentStore.
        :param filters: Filters applied to the retrieved Documents. Defaults to None.
        :param top_k: Maximum number of Documents to return, defaults to 10.

        :raises ValueError: If `document_store` is not an instance of PineconeDocumentStore.
        """
        if not isinstance(document_store, PineconeDocumentStore):
            msg = "document_store must be an instance of PineconeDocumentStore"
            raise ValueError(msg)

        self.document_store = document_store
        # Normalize None to an empty dict so `run` can pass it straight through.
        self.filters = filters or {}
        self.top_k = top_k

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :return: Dictionary with the component type and its init parameters,
            including the serialized document store.
        """
        return default_to_dict(
            self,
            filters=self.filters,
            top_k=self.top_k,
            document_store=self.document_store.to_dict(),
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PineconeDenseRetriever":
        """
        Deserialize this component from a dictionary.

        :param data: Dictionary produced by `to_dict`.
        :return: A new PineconeDenseRetriever instance.
        """
        # Rebuild the nested document store first: default_from_dict expects an
        # actual PineconeDocumentStore instance among the init parameters.
        data["init_parameters"]["document_store"] = default_from_dict(
            PineconeDocumentStore, data["init_parameters"]["document_store"]
        )
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(
        self,
        query_embedding: List[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: Optional[int] = None,
    ):
        """
        Retrieve documents from the PineconeDocumentStore, based on their dense embeddings.

        :param query_embedding: Embedding of the query.
        :param filters: Filters applied to the retrieved Documents.
            If not provided, the filters set at initialization time are used.
        :param top_k: Maximum number of Documents to return.
            If not provided, the `top_k` set at initialization time is used.
        :return: List of Document similar to `query_embedding`.
        """
        # Per-invocation overrides fall back to the values configured at init
        # time, keeping the original single-argument call fully compatible.
        filters = filters if filters is not None else self.filters
        top_k = top_k if top_k is not None else self.top_k
        docs = self.document_store._embedding_retrieval(
            query_embedding=query_embedding,
            filters=filters,
            top_k=top_k,
        )
        return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from haystack import default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores import DuplicatePolicy
from haystack.utils.filters import convert

from pinecone_haystack.filters import _normalize_filters

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -178,7 +181,7 @@ def _embedding_retrieval(
query_embedding: List[float],
*,
namespace: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None, # noqa: ARG002 (filters to be implemented)
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
) -> List[Document]:
"""
Expand All @@ -200,10 +203,15 @@ def _embedding_retrieval(
msg = "query_embedding must be a non-empty list of floats"
raise ValueError(msg)

if filters and "operator" not in filters and "conditions" not in filters:
filters = convert(filters)
filters = _normalize_filters(filters) if filters else None

result = self._index.query(
vector=query_embedding,
top_k=top_k,
namespace=namespace or self.namespace,
filter=filters,
include_values=True,
include_metadata=True,
)
Expand Down
187 changes: 187 additions & 0 deletions integrations/pinecone/src/pinecone_haystack/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict

from haystack.errors import FilterError
from pandas import DataFrame


def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert Haystack filters into Pinecone-compatible filters.
    Reference: https://docs.pinecone.io/docs/metadata-filtering
    """
    if not isinstance(filters, dict):
        msg = "Filters must be a dictionary"
        raise FilterError(msg)

    # A "field" key marks a comparison dictionary; anything else is
    # treated as a logical (AND/OR) condition.
    if "field" not in filters:
        return _parse_logical_condition(filters)
    return _parse_comparison_condition(filters)


def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
    """Translate a Haystack logical condition into its Pinecone form."""
    # Both keys are mandatory; report the first one that is missing.
    for required_key in ("operator", "conditions"):
        if required_key not in condition:
            msg = f"'{required_key}' key missing in {condition}"
            raise FilterError(msg)

    operator = condition["operator"]
    # Sub-conditions may themselves be logical or comparison dictionaries;
    # _parse_comparison_condition dispatches on the presence of "field".
    parsed_conditions = [_parse_comparison_condition(sub) for sub in condition["conditions"]]

    if operator not in LOGICAL_OPERATORS:
        msg = f"Unknown logical operator '{operator}'"
        raise FilterError(msg)

    return {LOGICAL_OPERATORS[operator]: parsed_conditions}


def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
    """
    Translate a Haystack comparison condition into its Pinecone form.

    :param condition: Dictionary with "field", "operator" and "value" keys
        (or a logical condition, to which this function falls back).
    :raises FilterError: If a required key is missing or the operator is unknown.
    """
    if "field" not in condition:
        # 'field' key is only found in comparison dictionaries.
        # We assume this is a logic dictionary since it's not present.
        return _parse_logical_condition(condition)

    field: str = condition["field"]
    if "operator" not in condition:
        msg = f"'operator' key missing in {condition}"
        raise FilterError(msg)
    if "value" not in condition:
        msg = f"'value' key missing in {condition}"
        raise FilterError(msg)
    operator: str = condition["operator"]
    if operator not in COMPARISON_OPERATORS:
        # Raise the module's FilterError instead of leaking a bare KeyError
        # from the dispatch-table lookup below.
        msg = f"Unknown comparison operator '{operator}'"
        raise FilterError(msg)
    if field.startswith("meta."):
        # Remove the "meta." prefix if present.
        # Documents are flattened when using the PineconeDocumentStore
        # so we don't need to specify the "meta." prefix.
        # Instead of raising an error we handle it gracefully.
        field = field[5:]

    value: Any = condition["value"]
    if isinstance(value, DataFrame):
        # Pinecone metadata cannot store a DataFrame; serialize it to JSON text.
        value = value.to_json()

    return COMPARISON_OPERATORS[operator](field, value)


def _equal(field: str, value: Any) -> Dict[str, Any]:
supported_types = (str, int, float, bool)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'equal' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$eq": value}}


def _not_equal(field: str, value: Any) -> Dict[str, Any]:
supported_types = (str, int, float, bool)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'not equal' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$ne": value}}


def _greater_than(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'greater than' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$gt": value}}


def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'greater than equal' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$gte": value}}


def _less_than(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'less than' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$lt": value}}


def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
supported_types = (int, float)
if not isinstance(value, supported_types):
msg = (
f"Unsupported type for 'less than equal' comparison: {type(value)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$lte": value}}


def _not_in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
raise FilterError(msg)

supported_types = (int, float, str)
for v in value:
if not isinstance(v, supported_types):
msg = (
f"Unsupported type for 'not in' comparison: {type(v)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$nin": value}}


def _in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
raise FilterError(msg)

supported_types = (int, float, str)
for v in value:
if not isinstance(v, supported_types):
msg = (
f"Unsupported type for 'in' comparison: {type(v)}. "
f"Types supported by Pinecone are: {supported_types}"
)
raise FilterError(msg)

return {field: {"$in": value}}


# Dispatch table mapping Haystack comparison operators to the functions
# that build the corresponding Pinecone filter clause.
COMPARISON_OPERATORS = {
    "==": _equal,
    "!=": _not_equal,
    ">": _greater_than,
    ">=": _greater_than_equal,
    "<": _less_than,
    "<=": _less_than_equal,
    "in": _in,
    "not in": _not_in,
}

# Maps Haystack logical operators to their Pinecone counterparts.
# NOTE(review): no "NOT" entry here — presumably negation is unsupported
# by this integration; confirm against the Haystack filter specification.
LOGICAL_OPERATORS = {"AND": "$and", "OR": "$or"}
100 changes: 100 additions & 0 deletions integrations/pinecone/tests/test_dense_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock, patch

from haystack.dataclasses import Document

from pinecone_haystack.dense_retriever import PineconeDenseRetriever
from pinecone_haystack.document_store import PineconeDocumentStore


def test_init_default():
    """Defaults: empty filters and top_k of 10."""
    store = Mock(spec=PineconeDocumentStore)

    retriever = PineconeDenseRetriever(document_store=store)

    assert retriever.document_store == store
    assert retriever.top_k == 10
    assert retriever.filters == {}


@patch("pinecone_haystack.document_store.pinecone")
def test_to_dict(mock_pinecone):
    """Serialization includes the nested document store parameters."""
    # The document store queries the index stats for its dimension at init time.
    mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512}
    store = PineconeDocumentStore(
        api_key="test-key",
        environment="gcp-starter",
        index="default",
        namespace="test-namespace",
        batch_size=50,
        dimension=512,
    )
    retriever = PineconeDenseRetriever(document_store=store)

    expected = {
        "type": "pinecone_haystack.dense_retriever.PineconeDenseRetriever",
        "init_parameters": {
            "document_store": {
                "init_parameters": {
                    "environment": "gcp-starter",
                    "index": "default",
                    "namespace": "test-namespace",
                    "batch_size": 50,
                    "dimension": 512,
                },
                "type": "pinecone_haystack.document_store.PineconeDocumentStore",
            },
            "filters": {},
            "top_k": 10,
        },
    }
    assert retriever.to_dict() == expected


@patch("pinecone_haystack.document_store.pinecone")
def test_from_dict(mock_pinecone, monkeypatch):
    """Deserialization restores both the retriever and its document store."""
    serialized = {
        "type": "pinecone_haystack.dense_retriever.PineconeDenseRetriever",
        "init_parameters": {
            "document_store": {
                "init_parameters": {
                    "environment": "gcp-starter",
                    "index": "default",
                    "namespace": "test-namespace",
                    "batch_size": 50,
                    "dimension": 512,
                },
                "type": "pinecone_haystack.document_store.PineconeDocumentStore",
            },
            "filters": {},
            "top_k": 10,
        },
    }

    # Index stats are read during document store construction; the API key
    # comes from the environment rather than the serialized data.
    mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512}
    monkeypatch.setenv("PINECONE_API_KEY", "test-key")

    retriever = PineconeDenseRetriever.from_dict(serialized)

    store = retriever.document_store
    assert store.environment == "gcp-starter"
    assert store.index == "default"
    assert store.namespace == "test-namespace"
    assert store.batch_size == 50
    assert store.dimension == 512

    assert retriever.filters == {}
    assert retriever.top_k == 10


def test_run():
    """run() delegates to the store's _embedding_retrieval and wraps the result."""
    store = Mock(spec=PineconeDocumentStore)
    store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
    retriever = PineconeDenseRetriever(document_store=store)

    res = retriever.run(query_embedding=[0.5, 0.7])

    store._embedding_retrieval.assert_called_once_with(
        query_embedding=[0.5, 0.7],
        filters={},
        top_k=10,
    )
    documents = res["documents"]
    assert len(res) == 1
    assert len(documents) == 1
    assert documents[0].content == "Test doc"
    assert documents[0].embedding == [0.1, 0.2]
Loading