From a631a470723fa114f34ab68051011b21c5c4787d Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Thu, 29 Feb 2024 12:06:23 +0100 Subject: [PATCH] docs: review chroma integration (#501) * docs: review chroma integration * docs: drop unused exception * docs: drop duplicate empty line --- .../components/retrievers/chroma/retriever.py | 67 +++++++++++++++---- .../document_stores/chroma/document_store.py | 45 ++++++++++--- .../document_stores/chroma/errors.py | 6 ++ .../document_stores/chroma/utils.py | 7 ++ 4 files changed, 105 insertions(+), 20 deletions(-) diff --git a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py index 2712f8c17..e19d4acbe 100644 --- a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py +++ b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py @@ -10,16 +10,42 @@ @component class ChromaQueryTextRetriever: """ - A component for retrieving documents from an ChromaDocumentStore using the `query` API. + A component for retrieving documents from a [Chroma database](https://docs.trychroma.com/) using the `query` API. + + Example usage: + ```python + from haystack import Pipeline + from haystack.components.converters import TextFileToDocument + from haystack.components.writers import DocumentWriter + + from haystack_integrations.document_stores.chroma import ChromaDocumentStore + from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever + + file_paths = ... 
+ + # Chroma is used in-memory so we use the same instances in the two pipelines below + document_store = ChromaDocumentStore() + + indexing = Pipeline() + indexing.add_component("converter", TextFileToDocument()) + indexing.add_component("writer", DocumentWriter(document_store)) + indexing.connect("converter", "writer") + indexing.run({"converter": {"sources": file_paths}}) + + querying = Pipeline() + querying.add_component("retriever", ChromaQueryTextRetriever(document_store)) + results = querying.run({"retriever": {"query": "Variable declarations", "top_k": 3}}) + + for d in results["retriever"]["documents"]: + print(d.meta, d.score) + ``` """ def __init__(self, document_store: ChromaDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10): """ - Create a ChromaQueryTextRetriever component. - - :param document_store: An instance of ChromaDocumentStore. - :param filters: A dictionary with filters to narrow down the search space (default is None). - :param top_k: The maximum number of documents to retrieve (default is 10). + :param document_store: an instance of `ChromaDocumentStore`. + :param filters: filters to narrow down the search space. + :param top_k: the maximum number of documents to retrieve. """ self.filters = filters self.top_k = top_k @@ -36,7 +62,10 @@ def run( Run the retriever on the given input data. :param query: The input data for the retriever. In this case, a plain-text query. - :return: The retrieved documents. + :param top_k: The maximum number of documents to retrieve. + If not specified, the default value from the constructor is used. + :returns: A dictionary with the following keys: + - "documents": List of documents returned by the search engine. :raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance. 
""" @@ -46,7 +75,10 @@ def run( def to_dict(self) -> Dict[str, Any]: """ - Override the default serializer in order to manage the Chroma client string representation + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. """ d = default_to_dict(self, filters=self.filters, top_k=self.top_k) d["init_parameters"]["document_store"] = self.document_store.to_dict() @@ -55,6 +87,14 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ document_store = ChromaDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store return default_from_dict(cls, data) @@ -62,6 +102,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever": @component class ChromaEmbeddingRetriever(ChromaQueryTextRetriever): + """ + A component for retrieving documents from a [Chroma database](https://docs.trychroma.com/) using embeddings. + """ + @component.output_types(documents=List[Document]) def run( self, @@ -72,10 +116,9 @@ def run( """ Run the retriever on the given input data. - :param queries: The input data for the retriever. In this case, a list of queries. - :return: The retrieved documents. - - :raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance. + :param query_embedding: the query embeddings. + :returns: a dictionary with the following keys: + - "documents": List of documents returned by the search engine. 
""" top_k = top_k or self.top_k diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 3249ec4f8..0da89c87e 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -19,6 +19,8 @@ class ChromaDocumentStore: """ + A document store using [Chroma](https://docs.trychroma.com/) as the backend. + We use the `collection.get` API to implement the document store protocol, the `collection.search` API will be used in the retriever instead. """ @@ -38,6 +40,11 @@ def __init__( Note: for the component to be part of a serializable pipeline, the __init__ parameters must be serializable, reason why we use a registry to configure the embedding function passing a string. + + :param collection_name: the name of the collection to use in the database. + :param embedding_function: the name of the embedding function to use to embed the query. + :param persist_path: where to store the database. If None, the database will be `in-memory`. + :param embedding_function_params: additional parameters to pass to the embedding function. """ # Store the params for marshalling self._collection_name = collection_name @@ -56,6 +63,8 @@ def __init__( def count_documents(self) -> int: """ Returns how many documents are present in the document store. + + :returns: how many documents are present in the document store. """ return self._collection.count() @@ -128,7 +137,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc ``` :param filters: the filters to apply to the document list. - :return: a list of Documents that match the given filters. + :returns: a list of Documents that match the given filters. 
""" if filters: ids, where, where_document = self._normalize_filters(filters) @@ -152,7 +161,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D :param documents: a list of documents. :param policy: not supported at the moment :raises DuplicateDocumentError: Exception trigger on duplicate document if `policy=DuplicatePolicy.FAIL` - :return: None + :returns: None """ for doc in documents: if not isinstance(doc, Document): @@ -177,15 +186,17 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D def delete_documents(self, document_ids: List[str]) -> None: """ Deletes all documents with a matching document_ids from the document store. - Fails with `MissingDocumentError` if no document with this id is present in the store. :param document_ids: the object_ids to delete """ self._collection.delete(ids=document_ids) def search(self, queries: List[str], top_k: int) -> List[List[Document]]: - """ - Perform vector search on the stored documents + """Search the documents in the store using the provided text queries. + + :param queries: the list of queries to search for. + :param top_k: the maximum number of documents to return for each query. + :returns: matching documents for each query. """ results = self._collection.query( query_texts=queries, n_results=top_k, include=["embeddings", "documents", "metadatas", "distances"] @@ -196,10 +207,14 @@ def search_embeddings( self, query_embeddings: List[List[float]], top_k: int, filters: Optional[Dict[str, Any]] = None ) -> List[List[Document]]: """ - Perform vector search on the stored document, pass the embeddings of the queries - instead of their text. + Perform vector search on the stored documents, passing the embeddings of the queries instead of their text. + + :param query_embeddings: a list of embeddings to use as queries. + :param top_k: the maximum number of documents to retrieve. + :param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format. 
+ + :returns: a list of lists of documents that match the given filters. - Accepts filters in haystack format. """ if filters is None: results = self._collection.query( @@ -221,9 +236,23 @@ def search_embeddings( @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ChromaDocumentStore": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ return ChromaDocumentStore(**data) def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ return { "collection_name": self._collection_name, "embedding_function": self._embedding_function, diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/errors.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/errors.py index aeb0230cd..7596a265f 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/errors.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/errors.py @@ -6,12 +6,18 @@ class ChromaDocumentStoreError(DocumentStoreError): + """Parent class for all ChromaDocumentStore exceptions.""" + pass class ChromaDocumentStoreFilterError(FilterError, ValueError): + """Raised when a filter is not valid for a ChromaDocumentStore.""" + pass class ChromaDocumentStoreConfigError(ChromaDocumentStoreError): + """Raised when a configuration is not valid for a ChromaDocumentStore.""" + pass diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py index b1a354af4..08d6db618 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py @@ -32,6 +32,13 @@ def get_embedding_function(function_name: str, **kwargs) -> EmbeddingFunction: + """Load 
an embedding function by name. + + :param function_name: the name of the embedding function. + :param kwargs: additional arguments to pass to the embedding function. + :returns: the loaded embedding function. + :raises ChromaDocumentStoreConfigError: if the function name is invalid. + """ try: return FUNCTION_REGISTRY[function_name](**kwargs) except KeyError: