From 54df1f51daf23bc95e035eb625b1bb98fd5f218e Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 12 Feb 2024 18:26:02 +0100 Subject: [PATCH] rename astraretriever --- integrations/astra/README.md | 8 ++++---- integrations/astra/examples/example.py | 4 ++-- integrations/astra/examples/pipeline_example.py | 4 ++-- .../components/retrievers/astra/__init__.py | 4 ++-- .../components/retrievers/astra/retriever.py | 6 +++--- integrations/astra/tests/test_retriever.py | 10 +++++----- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/integrations/astra/README.md b/integrations/astra/README.md index 9c3b264cb..75fdfeb6d 100644 --- a/integrations/astra/README.md +++ b/integrations/astra/README.md @@ -42,7 +42,7 @@ or ## Usage -This package includes Astra Document Store and Astra Retriever classes that integrate with Haystack, allowing you to easily perform document retrieval or RAG with Astra, and include those functions in Haystack pipelines. +This package includes Astra Document Store and Astra Embedding Retriever classes that integrate with Haystack, allowing you to easily perform document retrieval or RAG with Astra, and include those functions in Haystack pipelines. ### In order to use the Document Store directly: @@ -78,7 +78,7 @@ document_store = AstraDocumentStore( Then you can use the document store functions like count_document below: `document_store.count_documents()` -### Using the Astra Retriever with Haystack Pipelines +### Using the Astra Embedding Retriever with Haystack Pipelines Create the Document Store object like above, then import and create the Pipeline: @@ -87,8 +87,8 @@ from haystack.preview import Pipeline pipeline = Pipeline() ``` -Add your AstraRetriever into the pipeline -`pipeline.add_component(instance=AstraSingleRetriever(document_store=document_store), name="retriever")` +Add your AstraEmbeddingRetriever into the pipeline +`pipeline.add_component(instance=AstraEmbeddingRetriever(document_store=document_store), name="retriever")` Add other components and connect them as desired. Then run your pipeline: `pipeline.run(...)` diff --git a/integrations/astra/examples/example.py b/integrations/astra/examples/example.py index 35963868c..8ecb2eef0 100644 --- a/integrations/astra/examples/example.py +++ b/integrations/astra/examples/example.py @@ -10,7 +10,7 @@ from haystack.components.writers import DocumentWriter from haystack.document_stores.types import DuplicatePolicy -from haystack_integrations.components.retrievers.astra import AstraRetriever +from haystack_integrations.components.retrievers.astra import AstraEmbeddingRetriever from haystack_integrations.document_stores.astra import AstraDocumentStore logger = logging.getLogger(__name__) @@ -66,7 +66,7 @@ instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), name="embedder", ) -q.add_component("retriever", AstraRetriever(document_store)) +q.add_component("retriever", AstraEmbeddingRetriever(document_store)) q.connect("embedder", "retriever") diff --git a/integrations/astra/examples/pipeline_example.py b/integrations/astra/examples/pipeline_example.py index cacb1eb9f..ac87488d9 100644 --- a/integrations/astra/examples/pipeline_example.py +++ b/integrations/astra/examples/pipeline_example.py @@ -9,7 +9,7 @@ from haystack.components.writers import DocumentWriter from haystack.document_stores.types import DuplicatePolicy -from haystack_integrations.components.retrievers.astra import AstraRetriever +from haystack_integrations.components.retrievers.astra import AstraEmbeddingRetriever from haystack_integrations.document_stores.astra import AstraDocumentStore logger = logging.getLogger(__name__) @@ -77,7 +77,7 @@ instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), name="embedder", ) -rag_pipeline.add_component(instance=AstraRetriever(document_store=document_store), name="retriever") +rag_pipeline.add_component(instance=AstraEmbeddingRetriever(document_store=document_store), name="retriever") rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm") rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") diff --git a/integrations/astra/src/haystack_integrations/components/retrievers/astra/__init__.py b/integrations/astra/src/haystack_integrations/components/retrievers/astra/__init__.py index 33ef6d15e..ed4bfe40e 100644 --- a/integrations/astra/src/haystack_integrations/components/retrievers/astra/__init__.py +++ b/integrations/astra/src/haystack_integrations/components/retrievers/astra/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2023-present Anant Corporation # # SPDX-License-Identifier: Apache-2.0 -from .retriever import AstraRetriever +from .retriever import AstraEmbeddingRetriever -__all__ = ["AstraRetriever"] +__all__ = ["AstraEmbeddingRetriever"] diff --git a/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py index fdf9b0722..7236b5749 100644 --- a/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py +++ b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py @@ -10,14 +10,14 @@ @component -class AstraRetriever: +class AstraEmbeddingRetriever: """ A component for retrieving documents from an AstraDocumentStore. """ def __init__(self, document_store: AstraDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10): """ - Create an AstraRetriever component. Usually you pass some basic configuration + Create an AstraEmbeddingRetriever component. Usually you pass some basic configuration parameters to the constructor. :param filters: A dictionary with filters to narrow down the search space (default is None). @@ -59,7 +59,7 @@ def to_dict(self) -> Dict[str, Any]: ) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "AstraRetriever": + def from_dict(cls, data: Dict[str, Any]) -> "AstraEmbeddingRetriever": document_store = AstraDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store return default_from_dict(cls, data) diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py index eb9260590..c06a52edb 100644 --- a/integrations/astra/tests/test_retriever.py +++ b/integrations/astra/tests/test_retriever.py @@ -5,7 +5,7 @@ import pytest -from haystack_integrations.components.retrievers.astra import AstraRetriever +from haystack_integrations.components.retrievers.astra import AstraEmbeddingRetriever @pytest.mark.skipif( @@ -14,9 +14,9 @@ @pytest.mark.skipif(os.environ.get("ASTRA_DB_ID", "") == "", reason="ASTRA_DB_ID is not set") @pytest.mark.integration def test_retriever_to_json(document_store): - retriever = AstraRetriever(document_store, filters={"foo": "bar"}, top_k=99) + retriever = AstraEmbeddingRetriever(document_store, filters={"foo": "bar"}, top_k=99) assert retriever.to_dict() == { - "type": "haystack_integrations.components.retrievers.astra.retriever.AstraRetriever", + "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", "init_parameters": { "filters": {"foo": "bar"}, "top_k": 99, @@ -43,7 +43,7 @@ def test_retriever_to_json(document_store): @pytest.mark.integration def test_retriever_from_json(): data = { - "type": "haystack_integrations.components.retrievers.astra.retriever.AstraRetriever", + "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", "init_parameters": { "filters": {"bar": "baz"}, "top_k": 42, @@ -62,6 +62,6 @@ def test_retriever_from_json(): }, }, } - retriever = AstraRetriever.from_dict(data) + retriever = AstraEmbeddingRetriever.from_dict(data) assert retriever.top_k == 42 assert retriever.filters == {"bar": "baz"}