Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added distance_function property to ChromadocumentStore #817

Merged
merged 10 commits into from
Jun 30, 2024
Merged
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
collection_name: str = "documents",
embedding_function: str = "default",
persist_path: Optional[str] = None,
distance_function: Optional[str] = None,
Amnah199 marked this conversation as resolved.
Show resolved Hide resolved
**embedding_function_params,
):
"""
Expand All @@ -45,22 +46,41 @@ def __init__(
:param collection_name: the name of the collection to use in the database.
:param embedding_function: the name of the embedding function to use to embed the query
:param persist_path: where to store the database. If None, the database will be `in-memory`.
:param distance_function: distance metric for the embedding space. Valid options are
'l2', 'cosine' and 'ip'.
Amnah199 marked this conversation as resolved.
Show resolved Hide resolved
:param embedding_function_params: additional parameters to pass to the embedding function.
"""
# Store the params for marshalling
self._collection_name = collection_name
self._embedding_function = embedding_function
self._embedding_function_params = embedding_function_params
self._persist_path = persist_path
self._distance_function = distance_function
# Create the client instance
if persist_path is None:
self._chroma_client = chromadb.Client()
else:
self._chroma_client = chromadb.PersistentClient(path=persist_path)
self._collection = self._chroma_client.get_or_create_collection(
name=collection_name,
embedding_function=get_embedding_function(embedding_function, **embedding_function_params),
)

if distance_function:
if collection_name in [c.name for c in self._chroma_client.list_collections()]:
self._collection = self._chroma_client.get_collection(
collection_name,
embedding_function=get_embedding_function(embedding_function, **embedding_function_params),
)
logger.warning("Collection already exists. The 'distance_function' parameter will be ignored.")

else:
self._collection = self._chroma_client.create_collection(
name=collection_name,
metadata={"hnsw:space": distance_function},
embedding_function=get_embedding_function(embedding_function, **embedding_function_params),
)
else:
self._collection = self._chroma_client.get_or_create_collection(
name=collection_name,
embedding_function=get_embedding_function(embedding_function, **embedding_function_params),
)
Amnah199 marked this conversation as resolved.
Show resolved Hide resolved

def count_documents(self) -> int:
"""
Expand Down
47 changes: 47 additions & 0 deletions integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2023-present John Doe <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import logging
import operator
import uuid
from typing import List
Expand Down Expand Up @@ -126,11 +127,57 @@ def test_from_json(self):
assert ds._embedding_function == function_name
assert ds._embedding_function_params == {"api_key": "1234567890"}

@pytest.fixture(scope="class")
def setup_document_stores(self, request):
collection_name, distance_function = request.param
doc_store = ChromaDocumentStore(collection_name=collection_name, distance_function=distance_function)

docs = [
Document(content="Cats are small domesticated carnivorous mammals with soft fur.", meta={"id": "A"}),
Document(content="Dogs are loyal and friendly animals often kept as pets.", meta={"id": "B"}),
Document(content="Birds have feathers, wings, and beaks, and most can fly.", meta={"id": "C"}),
Document(content="Fish live in water, have gills, and are often covered with scales.", meta={"id": "D"}),
Document(
content="The sun is a star at the center of the solar system, providing light and heat to Earth.",
meta={"id": "E"},
),
]

doc_store.write_documents(docs)
return doc_store

@pytest.mark.parametrize("setup_document_stores", [("doc_store_cosine", "cosine")], indirect=True)
def test_cosine_similarity(self, setup_document_stores):
doc_store_cosine = setup_document_stores
query = ["Stars are astronomical objects consisting of a luminous spheroid of plasma."]
results_cosine = doc_store_cosine.search(query, top_k=1)[0]

assert results_cosine[0].score == pytest.approx(0.47612541913986206, abs=1e-3)

@pytest.mark.parametrize("setup_document_stores", [("doc_store_l2", "l2")], indirect=True)
def test_l2_similarity(self, setup_document_stores):
doc_store_l2 = setup_document_stores
query = ["Stars are astronomical objects consisting of a luminous spheroid of plasma."]
results_l2 = doc_store_l2.search(query, top_k=1)[0]

assert results_l2[0].score == pytest.approx(0.9522517323493958, abs=1e-3)
Amnah199 marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.integration
def test_same_collection_name_reinitialization(self):
ChromaDocumentStore("test_name")
ChromaDocumentStore("test_name")

@pytest.mark.integration
def test_distance_metric_initialization(self):
ChromaDocumentStore("test_name_2", distance_function="cosine")
Amnah199 marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.integration
def test_distance_metric_reinitialization(self, caplog):
ChromaDocumentStore("test_name_3", distance_function="cosine")

with caplog.at_level(logging.WARNING):
ChromaDocumentStore("test_name_3", distance_function="l2")
Amnah199 marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.skip(reason="Filter on dataframe contents is not supported.")
def test_filter_document_dataframe(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
pass
Expand Down
Loading