Skip to content

Commit

Permalink
rename retriever (#407)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 authored Feb 13, 2024
1 parent f85f838 commit 9c069c9
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 14 deletions.
4 changes: 2 additions & 2 deletions integrations/chroma/example/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from haystack.components.writers import DocumentWriter

from haystack_integrations.document_stores.chroma import ChromaDocumentStore
from haystack_integrations.components.retrievers.chroma import ChromaQueryRetriever
from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever

HERE = Path(__file__).resolve().parent
file_paths = [HERE / "data" / Path(name) for name in os.listdir("data")]
Expand All @@ -22,7 +22,7 @@
indexing.run({"converter": {"sources": file_paths}})

querying = Pipeline()
querying.add_component("retriever", ChromaQueryRetriever(document_store))
querying.add_component("retriever", ChromaQueryTextRetriever(document_store))
results = querying.run({"retriever": {"query": "Variable declarations", "top_k": 3}})

for d in results["retriever"]["documents"]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .retriever import ChromaEmbeddingRetriever, ChromaQueryRetriever
from .retriever import ChromaEmbeddingRetriever, ChromaQueryTextRetriever

__all__ = ["ChromaQueryRetriever", "ChromaEmbeddingRetriever"]
__all__ = ["ChromaQueryTextRetriever", "ChromaEmbeddingRetriever"]
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@


@component
class ChromaQueryRetriever:
class ChromaQueryTextRetriever:
"""
A component for retrieving documents from an ChromaDocumentStore using the `query` API.
"""

def __init__(self, document_store: ChromaDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10):
"""
Create an ExampleRetriever component. Usually you pass some basic configuration
parameters to the constructor.
Create a ChromaQueryTextRetriever component.
:param document_store: An instance of ChromaDocumentStore.
:param filters: A dictionary with filters to narrow down the search space (default is None).
:param top_k: The maximum number of documents to retrieve (default is 10).
"""
Expand Down Expand Up @@ -54,14 +54,14 @@ def to_dict(self) -> Dict[str, Any]:
return d

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryRetriever":
def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever":
document_store = ChromaDocumentStore.from_dict(data["init_parameters"]["document_store"])
data["init_parameters"]["document_store"] = document_store
return default_from_dict(cls, data)


@component
class ChromaEmbeddingRetriever(ChromaQueryRetriever):
class ChromaEmbeddingRetriever(ChromaQueryTextRetriever):
@component.output_types(documents=List[Document])
def run(
self,
Expand Down
10 changes: 5 additions & 5 deletions integrations/chroma/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from haystack_integrations.components.retrievers.chroma import ChromaQueryRetriever
from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever
from haystack_integrations.document_stores.chroma import ChromaDocumentStore


Expand All @@ -11,9 +11,9 @@ def test_retriever_to_json(request):
ds = ChromaDocumentStore(
collection_name=request.node.name, embedding_function="HuggingFaceEmbeddingFunction", api_key="1234567890"
)
retriever = ChromaQueryRetriever(ds, filters={"foo": "bar"}, top_k=99)
retriever = ChromaQueryTextRetriever(ds, filters={"foo": "bar"}, top_k=99)
assert retriever.to_dict() == {
"type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryRetriever",
"type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryTextRetriever",
"init_parameters": {
"filters": {"foo": "bar"},
"top_k": 99,
Expand All @@ -29,7 +29,7 @@ def test_retriever_to_json(request):
@pytest.mark.integration
def test_retriever_from_json(request):
data = {
"type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryRetriever",
"type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryTextRetriever",
"init_parameters": {
"filters": {"bar": "baz"},
"top_k": 42,
Expand All @@ -40,7 +40,7 @@ def test_retriever_from_json(request):
},
},
}
retriever = ChromaQueryRetriever.from_dict(data)
retriever = ChromaQueryTextRetriever.from_dict(data)
assert retriever.document_store._collection_name == request.node.name
assert retriever.document_store._embedding_function == "HuggingFaceEmbeddingFunction"
assert retriever.document_store._embedding_function_params == {"api_key": "1234567890"}
Expand Down

0 comments on commit 9c069c9

Please sign in to comment.