Skip to content

Commit

Permalink
feat: added distance_function property to ChromaDocumentStore (#817)
Browse files Browse the repository at this point in the history
* Added the distance metric property
---------

Co-authored-by: Amna Mubashar <[email protected]>
Co-authored-by: Stefano Fiorucci <[email protected]>
  • Loading branch information
Amnah199 and anakin87 authored Jun 30, 2024
1 parent 6d8ce95 commit 7127be6
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple

import chromadb
import numpy as np
Expand All @@ -18,6 +18,9 @@
logger = logging.getLogger(__name__)


# Distance metric names accepted for the `distance_function` argument of ChromaDocumentStore.
VALID_DISTANCE_FUNCTIONS = ("l2", "cosine", "ip")


class ChromaDocumentStore:
"""
A document store using [Chroma](https://docs.trychroma.com/) as the backend.
Expand All @@ -31,6 +34,7 @@ def __init__(
collection_name: str = "documents",
embedding_function: str = "default",
persist_path: Optional[str] = None,
distance_function: Literal["l2", "cosine", "ip"] = "l2",
**embedding_function_params,
):
"""
Expand All @@ -45,22 +49,51 @@ def __init__(
:param collection_name: the name of the collection to use in the database.
:param embedding_function: the name of the embedding function to use to embed the query
:param persist_path: where to store the database. If None, the database will be `in-memory`.
:param distance_function: The distance metric for the embedding space.
- `"l2"` computes the Euclidean (straight-line) distance between vectors,
where smaller scores indicate more similarity.
- `"cosine"` computes the cosine similarity between vectors,
with higher scores indicating greater similarity.
- `"ip"` stands for inner product, where higher scores indicate greater similarity between vectors.
**Note**: `distance_function` can only be set during the creation of a collection.
To change the distance metric of an existing collection, consider cloning the collection.
:param embedding_function_params: additional parameters to pass to the embedding function.
"""

if distance_function not in VALID_DISTANCE_FUNCTIONS:
error_message = (
f"Invalid distance_function: '{distance_function}' for the collection. "
f"Valid options are: {VALID_DISTANCE_FUNCTIONS}."
)
raise ValueError(error_message)

# Store the params for marshalling
self._collection_name = collection_name
self._embedding_function = embedding_function
self._embedding_function_params = embedding_function_params
self._persist_path = persist_path
self._distance_function = distance_function
# Create the client instance
if persist_path is None:
self._chroma_client = chromadb.Client()
else:
self._chroma_client = chromadb.PersistentClient(path=persist_path)
self._collection = self._chroma_client.get_or_create_collection(
name=collection_name,
embedding_function=get_embedding_function(embedding_function, **embedding_function_params),
)

embedding_func = get_embedding_function(embedding_function, **embedding_function_params)
metadata = {"hnsw:space": distance_function}

if collection_name in [c.name for c in self._chroma_client.list_collections()]:
self._collection = self._chroma_client.get_collection(collection_name, embedding_function=embedding_func)

if distance_function != self._collection.metadata["hnsw:space"]:
logger.warning("Collection already exists. The `distance_function` parameter will be ignored.")
else:
self._collection = self._chroma_client.create_collection(
name=collection_name,
metadata=metadata,
embedding_function=embedding_func,
)

def count_documents(self) -> int:
"""
Expand Down Expand Up @@ -290,6 +323,7 @@ def to_dict(self) -> Dict[str, Any]:
collection_name=self._collection_name,
embedding_function=self._embedding_function,
persist_path=self._persist_path,
distance_function=self._distance_function,
**self._embedding_function_params,
)

Expand Down
26 changes: 24 additions & 2 deletions integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2023-present John Doe <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import logging
import operator
import uuid
from typing import List
Expand Down Expand Up @@ -104,6 +105,7 @@ def test_to_json(self, request):
"embedding_function": "HuggingFaceEmbeddingFunction",
"persist_path": None,
"api_key": "1234567890",
"distance_function": "l2",
},
}

Expand All @@ -118,6 +120,7 @@ def test_from_json(self):
"embedding_function": "HuggingFaceEmbeddingFunction",
"persist_path": None,
"api_key": "1234567890",
"distance_function": "l2",
},
}

Expand All @@ -128,8 +131,27 @@ def test_from_json(self):

# Integration test: constructing a ChromaDocumentStore more than once with the
# same collection name must not raise (the test has no assertions; it passes
# as long as every constructor call succeeds).
# NOTE(review): this span comes from a rendered diff without +/- markers, so the
# "test_name" and "test_1" pairs are presumably the removed and the added version
# of the same two calls — confirm against the actual test file.
@pytest.mark.integration
def test_same_collection_name_reinitialization(self):
ChromaDocumentStore("test_name")
ChromaDocumentStore("test_name")
ChromaDocumentStore("test_1")
ChromaDocumentStore("test_1")

@pytest.mark.integration
def test_distance_metric_initialization(self):
    """
    Check that `distance_function` ends up in the collection metadata and that
    an unsupported value is rejected.
    """
    cosine_store = ChromaDocumentStore("test_2", distance_function="cosine")
    configured_space = cosine_store._collection.metadata["hnsw:space"]
    assert configured_space == "cosine"

    # A metric name outside the supported set must raise a ValueError.
    with pytest.raises(ValueError):
        ChromaDocumentStore("test_3", distance_function="jaccard")

@pytest.mark.integration
def test_distance_metric_reinitialization(self, caplog):
    """
    Check that re-opening an existing collection with a different
    `distance_function` logs a warning and keeps the collection's original metric.
    """
    expected_warning = "Collection already exists. The `distance_function` parameter will be ignored."

    original = ChromaDocumentStore("test_4", distance_function="cosine")

    # Capture WARNING-level records while the second store attaches to the same collection.
    with caplog.at_level(logging.WARNING):
        reopened = ChromaDocumentStore("test_4", distance_function="ip")

    assert expected_warning in caplog.text

    # Both handles must still report the metric the collection was created with.
    for store in (original, reopened):
        assert store._collection.metadata["hnsw:space"] == "cosine"

@pytest.mark.skip(reason="Filter on dataframe contents is not supported.")
def test_filter_document_dataframe(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
Expand Down
2 changes: 2 additions & 0 deletions integrations/chroma/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def test_retriever_to_json(request):
"embedding_function": "HuggingFaceEmbeddingFunction",
"persist_path": None,
"api_key": "1234567890",
"distance_function": "l2",
},
},
},
Expand All @@ -44,6 +45,7 @@ def test_retriever_from_json(request):
"embedding_function": "HuggingFaceEmbeddingFunction",
"persist_path": ".",
"api_key": "1234567890",
"distance_function": "l2",
},
},
},
Expand Down

0 comments on commit 7127be6

Please sign in to comment.