Skip to content

Commit

Permalink
Merge branch 'main' into feature/ollama_streaming_support
Browse files Browse the repository at this point in the history
  • Loading branch information
sachinsachdeva authored Feb 6, 2024
2 parents a6a50e5 + c8c73b7 commit 84a4135
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class ChromaEmbeddingRetriever(ChromaQueryRetriever):
def run(
self,
query_embedding: List[float],
_: Optional[Dict[str, Any]] = None, # filters not yet supported
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
):
"""
Expand All @@ -80,4 +80,4 @@ def run(
top_k = top_k or self.top_k

query_embeddings = [query_embedding]
return {"documents": self.document_store.search_embeddings(query_embeddings, top_k)[0]}
return {"documents": self.document_store.search_embeddings(query_embeddings, top_k, filters)[0]}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ class ChromaDocumentStore:
"""

def __init__(
self, collection_name: str = "documents", embedding_function: str = "default", **embedding_function_params
self,
collection_name: str = "documents",
embedding_function: str = "default",
persist_path: Optional[str] = None,
**embedding_function_params,
):
"""
Initializes the store. The __init__ constructor is not part of the Store Protocol
Expand All @@ -40,7 +44,10 @@ def __init__(
self._embedding_function = embedding_function
self._embedding_function_params = embedding_function_params
# Create the client instance
self._chroma_client = chromadb.Client()
if persist_path is None:
self._chroma_client = chromadb.Client()
else:
self._chroma_client = chromadb.PersistentClient(path=persist_path)
self._collection = self._chroma_client.get_or_create_collection(
name=collection_name,
embedding_function=get_embedding_function(embedding_function, **embedding_function_params),
Expand Down Expand Up @@ -185,16 +192,31 @@ def search(self, queries: List[str], top_k: int) -> List[List[Document]]:
)
return self._query_result_to_documents(results)

def search_embeddings(self, query_embeddings: List[List[float]], top_k: int) -> List[List[Document]]:
def search_embeddings(
self, query_embeddings: List[List[float]], top_k: int, filters: Optional[Dict[str, Any]] = None
) -> List[List[Document]]:
"""
Perform vector search on the stored documents, passing the embeddings of the queries
instead of their text
instead of their text.
Accepts filters in haystack format.
"""
results = self._collection.query(
query_embeddings=query_embeddings,
n_results=top_k,
include=["embeddings", "documents", "metadatas", "distances"],
)
if filters is None:
results = self._collection.query(
query_embeddings=query_embeddings,
n_results=top_k,
include=["embeddings", "documents", "metadatas", "distances"],
)
else:
chroma_filters = self._normalize_filters(filters=filters)
results = self._collection.query(
query_embeddings=query_embeddings,
n_results=top_k,
where=chroma_filters[1],
where_document=chroma_filters[2],
include=["embeddings", "documents", "metadatas", "distances"],
)

return self._query_result_to_documents(results)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class UpTrainEvaluator:
A component that uses the UpTrain framework to evaluate inputs against a specific metric.
The supported metrics are defined by :class:`UpTrainMetric`. The inputs of the component are
metric-dependent. The output is a list of :class:`UpTrainEvaluatorOutput` objects, each
containing a single input and the result of the evaluation performed on it.
metric-dependent. The output is a nested list of evaluation results where each inner list
contains the results for a single input.
"""

_backend_metric: Union[Evals, ParametricEval]
Expand Down

0 comments on commit 84a4135

Please sign in to comment.