Skip to content

Commit

Permalink
Allow persitent docstore and filters
Browse files Browse the repository at this point in the history
  • Loading branch information
illorca-verbi committed Feb 5, 2024
1 parent 2b608af commit ca446f3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class ChromaEmbeddingRetriever(ChromaQueryRetriever):
def run(
self,
query_embedding: List[float],
_: Optional[Dict[str, Any]] = None, # filters not yet supported
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
):
"""
Expand All @@ -80,4 +80,4 @@ def run(
top_k = top_k or self.top_k

query_embeddings = [query_embedding]
return {"documents": self.document_store.search_embeddings(query_embeddings, top_k)[0]}
return {"documents": self.document_store.search_embeddings(query_embeddings, top_k, filters)[0]}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ChromaDocumentStore:
"""

def __init__(
self, collection_name: str = "documents", embedding_function: str = "default", **embedding_function_params
self, collection_name: str = "documents", embedding_function: str = "default", persist_path: str = None, **embedding_function_params
):
"""
Initializes the store. The __init__ constructor is not part of the Store Protocol
Expand All @@ -40,7 +40,7 @@ def __init__(
self._embedding_function = embedding_function
self._embedding_function_params = embedding_function_params
# Create the client instance
self._chroma_client = chromadb.Client()
self._chroma_client = chromadb.Client() if persist_path == None else chromadb.PersistentClient(path=persist_path)
self._collection = self._chroma_client.get_or_create_collection(
name=collection_name,
embedding_function=get_embedding_function(embedding_function, **embedding_function_params),
Expand Down Expand Up @@ -185,14 +185,23 @@ def search(self, queries: List[str], top_k: int) -> List[List[Document]]:
)
return self._query_result_to_documents(results)

def search_embeddings(self, query_embeddings: List[List[float]], top_k: int) -> List[List[Document]]:
def search_embeddings(self, query_embeddings: List[List[float]], top_k: int, filters: Optional[Dict[str, Any]] = None) -> List[List[Document]]:
"""
Perform vector search on the stored document, pass the embeddings of the queries
instead of their text
instead of their text.
Accepts filters in haystack format.
"""
if filters is not None:
chroma_filters = self._normalize_filters(filters=filters)
else:
chroma_filters = (None, None, None)

results = self._collection.query(
query_embeddings=query_embeddings,
n_results=top_k,
where=chroma_filters[1],
where_document=chroma_filters[2],
include=["embeddings", "documents", "metadatas", "distances"],
)
return self._query_result_to_documents(results)
Expand Down

0 comments on commit ca446f3

Please sign in to comment.