diff --git a/integrations/chroma/example/example.py b/integrations/chroma/example/example.py index 1e6a7e402..d9b0f414e 100644 --- a/integrations/chroma/example/example.py +++ b/integrations/chroma/example/example.py @@ -7,7 +7,7 @@ from haystack.components.writers import DocumentWriter from haystack_integrations.document_stores.chroma import ChromaDocumentStore -from haystack_integrations.components.retrievers.chroma import ChromaQueryRetriever +from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever HERE = Path(__file__).resolve().parent file_paths = [HERE / "data" / Path(name) for name in os.listdir("data")] @@ -22,7 +22,7 @@ indexing.run({"converter": {"sources": file_paths}}) querying = Pipeline() -querying.add_component("retriever", ChromaQueryRetriever(document_store)) +querying.add_component("retriever", ChromaQueryTextRetriever(document_store)) results = querying.run({"retriever": {"query": "Variable declarations", "top_k": 3}}) for d in results["retriever"]["documents"]: diff --git a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/__init__.py b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/__init__.py index d02300de2..53120c97c 100644 --- a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/__init__.py +++ b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/__init__.py @@ -1,3 +1,3 @@ -from .retriever import ChromaEmbeddingRetriever, ChromaQueryRetriever +from .retriever import ChromaEmbeddingRetriever, ChromaQueryTextRetriever -__all__ = ["ChromaQueryRetriever", "ChromaEmbeddingRetriever"] +__all__ = ["ChromaQueryTextRetriever", "ChromaEmbeddingRetriever"] diff --git a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py index 55ae23e64..2712f8c17 100644 --- a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py +++ b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py @@ -8,16 +8,16 @@ @component -class ChromaQueryRetriever: +class ChromaQueryTextRetriever: """ A component for retrieving documents from an ChromaDocumentStore using the `query` API. """ def __init__(self, document_store: ChromaDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10): """ - Create an ExampleRetriever component. Usually you pass some basic configuration - parameters to the constructor. + Create a ChromaQueryTextRetriever component. + :param document_store: An instance of ChromaDocumentStore. :param filters: A dictionary with filters to narrow down the search space (default is None). :param top_k: The maximum number of documents to retrieve (default is 10). """ @@ -54,14 +54,14 @@ def to_dict(self) -> Dict[str, Any]: return d @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryRetriever": + def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever": document_store = ChromaDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store return default_from_dict(cls, data) @component -class ChromaEmbeddingRetriever(ChromaQueryRetriever): +class ChromaEmbeddingRetriever(ChromaQueryTextRetriever): @component.output_types(documents=List[Document]) def run( self, diff --git a/integrations/chroma/tests/test_retriever.py b/integrations/chroma/tests/test_retriever.py index 780b4144a..88969d725 100644 --- a/integrations/chroma/tests/test_retriever.py +++ b/integrations/chroma/tests/test_retriever.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 import pytest -from haystack_integrations.components.retrievers.chroma import ChromaQueryRetriever +from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever from haystack_integrations.document_stores.chroma import ChromaDocumentStore @@ -11,9 +11,9 @@ def test_retriever_to_json(request): ds = ChromaDocumentStore( collection_name=request.node.name, embedding_function="HuggingFaceEmbeddingFunction", api_key="1234567890" ) - retriever = ChromaQueryRetriever(ds, filters={"foo": "bar"}, top_k=99) + retriever = ChromaQueryTextRetriever(ds, filters={"foo": "bar"}, top_k=99) assert retriever.to_dict() == { - "type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryRetriever", + "type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryTextRetriever", "init_parameters": { "filters": {"foo": "bar"}, "top_k": 99, @@ -29,7 +29,7 @@ def test_retriever_to_json(request): @pytest.mark.integration def test_retriever_from_json(request): data = { - "type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryRetriever", + "type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryTextRetriever", "init_parameters": { "filters": {"bar": "baz"}, "top_k": 42, @@ -40,7 +40,7 @@ def test_retriever_from_json(request): }, }, } - retriever = ChromaQueryRetriever.from_dict(data) + retriever = ChromaQueryTextRetriever.from_dict(data) assert retriever.document_store._collection_name == request.node.name assert retriever.document_store._embedding_function == "HuggingFaceEmbeddingFunction" assert retriever.document_store._embedding_function_params == {"api_key": "1234567890"}