Skip to content

Commit

Permalink
docs: review chroma integration (#501)
Browse files Browse the repository at this point in the history
* docs: review chroma integration

* docs: drop unused exception

* docs: drop duplicate empty line
  • Loading branch information
wochinge authored Feb 29, 2024
1 parent e269dd8 commit a631a47
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,42 @@
@component
class ChromaQueryTextRetriever:
"""
A component for retrieving documents from an ChromaDocumentStore using the `query` API.
A component for retrieving documents from a [Chroma database](https://docs.trychroma.com/) using the `query` API.
Example usage:
```python
from haystack import Pipeline
from haystack.components.converters import TextFileToDocument
from haystack.components.writers import DocumentWriter
from haystack_integrations.document_stores.chroma import ChromaDocumentStore
from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever
file_paths = ...
# Chroma is used in-memory so we use the same instances in the two pipelines below
document_store = ChromaDocumentStore()
indexing = Pipeline()
indexing.add_component("converter", TextFileToDocument())
indexing.add_component("writer", DocumentWriter(document_store))
indexing.connect("converter", "writer")
indexing.run({"converter": {"sources": file_paths}})
querying = Pipeline()
querying.add_component("retriever", ChromaQueryTextRetriever(document_store))
results = querying.run({"retriever": {"query": "Variable declarations", "top_k": 3}})
for d in results["retriever"]["documents"]:
print(d.meta, d.score)
```
"""

def __init__(self, document_store: ChromaDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10):
"""
Create a ChromaQueryTextRetriever component.
:param document_store: An instance of ChromaDocumentStore.
:param filters: A dictionary with filters to narrow down the search space (default is None).
:param top_k: The maximum number of documents to retrieve (default is 10).
:param document_store: an instance of `ChromaDocumentStore`.
:param filters: filters to narrow down the search space.
:param top_k: the maximum number of documents to retrieve.
"""
self.filters = filters
self.top_k = top_k
Expand All @@ -36,7 +62,10 @@ def run(
Run the retriever on the given input data.
:param query: The input data for the retriever. In this case, a plain-text query.
:return: The retrieved documents.
:param top_k: The maximum number of documents to retrieve.
If not specified, the default value from the constructor is used.
:return: A dictionary with the following keys:
- "documents": List of documents returned by the search engine.
:raises ValueError: If the specified document store is not found or is not a ChromaDocumentStore instance.
"""
Expand All @@ -46,7 +75,10 @@ def run(

def to_dict(self) -> Dict[str, Any]:
"""
Override the default serializer in order to manage the Chroma client string representation
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
d = default_to_dict(self, filters=self.filters, top_k=self.top_k)
d["init_parameters"]["document_store"] = self.document_store.to_dict()
Expand All @@ -55,13 +87,25 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever":
    """
    Rebuild a retriever from its dictionary representation.

    The nested document store entry is deserialized first and written back into
    ``data["init_parameters"]`` (the input dictionary is modified in place) before
    the generic deserializer is invoked.

    :param data:
        Dictionary to deserialize from. It must contain an ``init_parameters``
        mapping with a ``document_store`` entry.
    :returns:
        Deserialized component.
    """
    init_params = data["init_parameters"]
    # Swap the serialized store for a live instance so the generic helper can
    # hand it straight to the constructor.
    init_params["document_store"] = ChromaDocumentStore.from_dict(init_params["document_store"])
    return default_from_dict(cls, data)


@component
class ChromaEmbeddingRetriever(ChromaQueryTextRetriever):
"""
A component for retrieving documents from a [Chroma database](https://docs.trychroma.com/) using embeddings.
"""

@component.output_types(documents=List[Document])
def run(
self,
Expand All @@ -72,10 +116,9 @@ def run(
"""
Run the retriever on the given input data.
:param queries: The input data for the retriever. In this case, a list of queries.
:return: The retrieved documents.
:raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance.
:param query_embedding: the query embeddings.
:return: a dictionary with the following keys:
- "documents": List of documents returned by the search engine.
"""
top_k = top_k or self.top_k

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

class ChromaDocumentStore:
"""
A document store using [Chroma](https://docs.trychroma.com/) as the backend.
We use the `collection.get` API to implement the document store protocol;
the `collection.query` API is used for retrieval instead.
"""
Expand All @@ -38,6 +40,11 @@ def __init__(
Note: for the component to be part of a serializable pipeline, the __init__
parameters must be serializable. This is why the embedding function is configured
by passing its name as a string, which is then looked up in a registry.
:param collection_name: the name of the collection to use in the database.
:param embedding_function: the name of the embedding function to use to embed the query
:param persist_path: where to store the database. If None, the database will be `in-memory`.
:param embedding_function_params: additional parameters to pass to the embedding function.
"""
# Store the params for marshalling
self._collection_name = collection_name
Expand All @@ -56,6 +63,8 @@ def __init__(
def count_documents(self) -> int:
    """
    Count the documents held in the underlying collection.

    :returns: the number reported by the collection's `count()` call.
    """
    collection = self._collection
    return collection.count()

Expand Down Expand Up @@ -128,7 +137,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
```
:param filters: the filters to apply to the document list.
:return: a list of Documents that match the given filters.
:returns: a list of Documents that match the given filters.
"""
if filters:
ids, where, where_document = self._normalize_filters(filters)
Expand All @@ -152,7 +161,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
:param documents: a list of documents.
:param policy: not supported at the moment
:raises DuplicateDocumentError: raised when a duplicate document is found and `policy=DuplicatePolicy.FAIL`.
:return: None
:returns: None
"""
for doc in documents:
if not isinstance(doc, Document):
Expand All @@ -177,15 +186,17 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
def delete_documents(self, document_ids: List[str]) -> None:
    """
    Remove the documents with the given ids from the underlying collection.

    The ids are forwarded unchanged to the collection's `delete` call; this method
    performs no existence check of its own.

    :param document_ids: ids of the documents to delete.
    """
    collection = self._collection
    collection.delete(ids=document_ids)

def search(self, queries: List[str], top_k: int) -> List[List[Document]]:
"""
Perform vector search on the stored documents
"""Search the documents in the store using the provided text queries.
:param queries: the list of queries to search for.
:param top_k: top_k documents to return for each query.
:return: matching documents for each query.
"""
results = self._collection.query(
query_texts=queries, n_results=top_k, include=["embeddings", "documents", "metadatas", "distances"]
Expand All @@ -196,10 +207,14 @@ def search_embeddings(
self, query_embeddings: List[List[float]], top_k: int, filters: Optional[Dict[str, Any]] = None
) -> List[List[Document]]:
"""
Perform vector search on the stored document, pass the embeddings of the queries
instead of their text.
Perform vector search on the stored documents, using the embeddings of the queries instead of their text.
:param query_embeddings: a list of embeddings to use as queries.
:param top_k: the maximum number of documents to retrieve.
:param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format.
:returns: a list of lists of documents that match the given filters.
Accepts filters in haystack format.
"""
if filters is None:
results = self._collection.query(
Expand All @@ -221,9 +236,23 @@ def search_embeddings(

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ChromaDocumentStore":
    """
    Deserializes the document store from a dictionary.

    The dictionary entries are passed to the constructor as keyword arguments.
    The instance is built through ``cls`` rather than a hard-coded class name, so
    calling this on a subclass yields an instance of that subclass.

    :param data:
        Dictionary to deserialize from; its keys must match the constructor's
        keyword arguments.
    :returns:
        Deserialized document store.
    """
    return cls(**data)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return {
"collection_name": self._collection_name,
"embedding_function": self._embedding_function,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@


class ChromaDocumentStoreError(DocumentStoreError):
    """Base error for the Chroma document store; more specific errors may derive from it."""


class ChromaDocumentStoreFilterError(FilterError, ValueError):
    """Signals that a filter cannot be used with a ChromaDocumentStore.

    Also a ``ValueError``, so it can be caught by handlers for either base type.
    """


class ChromaDocumentStoreConfigError(ChromaDocumentStoreError):
    """Signals an invalid configuration for a ChromaDocumentStore."""
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@


def get_embedding_function(function_name: str, **kwargs) -> EmbeddingFunction:
"""Load an embedding function by name.
:param function_name: the name of the embedding function.
:param kwargs: additional arguments to pass to the embedding function.
:returns: the loaded embedding function.
:raises ChromaDocumentStoreConfigError: if the function name is invalid.
"""
try:
return FUNCTION_REGISTRY[function_name](**kwargs)
except KeyError:
Expand Down

0 comments on commit a631a47

Please sign in to comment.