diff --git a/integrations/pinecone/src/pinecone_haystack/dense_retriever.py b/integrations/pinecone/src/pinecone_haystack/dense_retriever.py
new file mode 100644
index 000000000..3f60f252b
--- /dev/null
+++ b/integrations/pinecone/src/pinecone_haystack/dense_retriever.py
@@ -0,0 +1,90 @@
+# SPDX-FileCopyrightText: 2023-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, Dict, List, Optional
+
+from haystack import component, default_from_dict, default_to_dict
+from haystack.dataclasses import Document
+
+from pinecone_haystack.document_store import PineconeDocumentStore
+
+
+@component
+class PineconeDenseRetriever:
+    """
+    Retrieves documents from the PineconeDocumentStore, based on their dense embeddings.
+
+    Needs to be connected to the PineconeDocumentStore.
+    """
+
+    def __init__(
+        self,
+        *,
+        document_store: PineconeDocumentStore,
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: int = 10,
+    ):
+        """
+        Create the PineconeDenseRetriever component.
+
+        :param document_store: An instance of PineconeDocumentStore.
+        :param filters: Filters applied to the retrieved Documents. Defaults to None, which is
+            stored as an empty dictionary.
+        :param top_k: Maximum number of Documents to return, defaults to 10.
+
+        :raises ValueError: If `document_store` is not an instance of PineconeDocumentStore.
+        """
+        if not isinstance(document_store, PineconeDocumentStore):
+            msg = "document_store must be an instance of PineconeDocumentStore"
+            raise ValueError(msg)
+
+        self.document_store = document_store
+        self.filters = filters or {}
+        self.top_k = top_k
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize the retriever to a dictionary.
+
+        :return: Dictionary holding the component type and its init parameters, with the
+            document store serialized through its own `to_dict`.
+        """
+        return default_to_dict(
+            self,
+            filters=self.filters,
+            top_k=self.top_k,
+            document_store=self.document_store.to_dict(),
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "PineconeDenseRetriever":
+        """
+        Deserialize a retriever from a dictionary such as the one returned by `to_dict`.
+
+        :param data: Dictionary with the component type and, under `init_parameters`, the init
+            parameters including the serialized `document_store`. It is not modified.
+        :return: The deserialized PineconeDenseRetriever.
+        """
+        # Rebuild the store in a copy of the init parameters: writing the live store back into
+        # `data` would leave the caller's dictionary holding an object instead of its serialized form.
+        init_parameters = dict(data["init_parameters"])
+        init_parameters["document_store"] = default_from_dict(
+            PineconeDocumentStore, init_parameters["document_store"]
+        )
+        return default_from_dict(cls, {**data, "init_parameters": init_parameters})
+
+    @component.output_types(documents=List[Document])
+    def run(self, query_embedding: List[float]):
+        """
+        Retrieve documents from the PineconeDocumentStore, based on their dense embeddings.
+
+        :param query_embedding: Embedding of the query.
+        :return: Dictionary whose `documents` key holds the list of Document similar to
+            `query_embedding`.
+        """
+        docs = self.document_store._embedding_retrieval(
+            query_embedding=query_embedding,
+            filters=self.filters,
+            top_k=self.top_k,
+        )
+        return {"documents": docs}
diff --git a/integrations/pinecone/tests/test_dense_retriever.py b/integrations/pinecone/tests/test_dense_retriever.py
new file mode 100644
index 000000000..ceb73b687
--- /dev/null
+++ b/integrations/pinecone/tests/test_dense_retriever.py
@@ -0,0 +1,105 @@
+# SPDX-FileCopyrightText: 2023-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+from copy import deepcopy
+from unittest.mock import Mock, patch
+
+from haystack.dataclasses import Document
+
+from pinecone_haystack.dense_retriever import PineconeDenseRetriever
+from pinecone_haystack.document_store import PineconeDocumentStore
+
+
+def test_init_default():
+    mock_store = Mock(spec=PineconeDocumentStore)
+    retriever = PineconeDenseRetriever(document_store=mock_store)
+    assert retriever.document_store == mock_store
+    assert retriever.filters == {}
+    assert retriever.top_k == 10
+
+
+@patch("pinecone_haystack.document_store.pinecone")
+def test_to_dict(mock_pinecone):
+    mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512}
+    document_store = PineconeDocumentStore(
+        api_key="test-key",
+        environment="gcp-starter",
+        index="default",
+        namespace="test-namespace",
+        batch_size=50,
+        dimension=512,
+    )
+    retriever = PineconeDenseRetriever(document_store=document_store)
+    res = retriever.to_dict()
+    assert res == {
+        "type": "pinecone_haystack.dense_retriever.PineconeDenseRetriever",
+        "init_parameters": {
+            "document_store": {
+                "init_parameters": {
+                    "environment": "gcp-starter",
+                    "index": "default",
+                    "namespace": "test-namespace",
+                    "batch_size": 50,
+                    "dimension": 512,
+                },
+                "type": "pinecone_haystack.document_store.PineconeDocumentStore",
+            },
+            "filters": {},
+            "top_k": 10,
+        },
+    }
+
+
+@patch("pinecone_haystack.document_store.pinecone")
+def test_from_dict(mock_pinecone, monkeypatch):
+    data = {
+        "type": "pinecone_haystack.dense_retriever.PineconeDenseRetriever",
+        "init_parameters": {
+            "document_store": {
+                "init_parameters": {
+                    "environment": "gcp-starter",
+                    "index": "default",
+                    "namespace": "test-namespace",
+                    "batch_size": 50,
+                    "dimension": 512,
+                },
+                "type": "pinecone_haystack.document_store.PineconeDocumentStore",
+            },
+            "filters": {},
+            "top_k": 10,
+        },
+    }
+    original_data = deepcopy(data)
+
+    mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512}
+    monkeypatch.setenv("PINECONE_API_KEY", "test-key")
+    retriever = PineconeDenseRetriever.from_dict(data)
+
+    # Deserialization must leave the caller's dictionary reusable.
+    assert data == original_data
+
+    document_store = retriever.document_store
+    assert document_store.environment == "gcp-starter"
+    assert document_store.index == "default"
+    assert document_store.namespace == "test-namespace"
+    assert document_store.batch_size == 50
+    assert document_store.dimension == 512
+
+    assert retriever.filters == {}
+    assert retriever.top_k == 10
+
+
+def test_run():
+    mock_store = Mock(spec=PineconeDocumentStore)
+    mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
+    retriever = PineconeDenseRetriever(document_store=mock_store)
+    res = retriever.run(query_embedding=[0.5, 0.7])
+    mock_store._embedding_retrieval.assert_called_once_with(
+        query_embedding=[0.5, 0.7],
+        filters={},
+        top_k=10,
+    )
+    assert len(res) == 1
+    assert len(res["documents"]) == 1
+    assert res["documents"][0].content == "Test doc"
+    assert res["documents"][0].embedding == [0.1, 0.2]