Skip to content

Commit

Permalink
feat: Add metadata parameter to ChromaDocumentStore. (#906)
Browse files: browse the repository at this point in the history.
* feat: Add metadata parameter to ChromaDocumentStore.

* Update integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py

Co-authored-by: Stefano Fiorucci <[email protected]>

* Update integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py

Co-authored-by: Stefano Fiorucci <[email protected]>

* Update integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py

Co-authored-by: Stefano Fiorucci <[email protected]>

* test: update logging message in chroma document store tests

* style: Fix formatting

* test: Add test for logging messages when creating chroma collection with the same name.

* test: Fix logging message.

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
  • Loading branch information
guillaumecherel and anakin87 authored Jul 16, 2024
1 parent 1b3b36e commit ecaeedd
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
embedding_function: str = "default",
persist_path: Optional[str] = None,
distance_function: Literal["l2", "cosine", "ip"] = "l2",
metadata: Optional[dict] = None,
**embedding_function_params,
):
"""
Expand All @@ -57,6 +58,9 @@ def __init__(
- `"ip"` stands for inner product, where higher scores indicate greater similarity between vectors.
**Note**: `distance_function` can only be set during the creation of a collection.
To change the distance metric of an existing collection, consider cloning the collection.
:param metadata: a dictionary of chromadb collection parameters passed directly to chromadb's client
method `create_collection`. If it contains the key `"hnsw:space"`, the value will take precedence over the
`distance_function` parameter above.
:param embedding_function_params: additional parameters to pass to the embedding function.
"""
Expand All @@ -81,13 +85,18 @@ def __init__(
self._chroma_client = chromadb.PersistentClient(path=persist_path)

embedding_func = get_embedding_function(embedding_function, **embedding_function_params)
metadata = {"hnsw:space": distance_function}

metadata = metadata or {}
if "hnsw:space" not in metadata:
metadata["hnsw:space"] = distance_function

if collection_name in [c.name for c in self._chroma_client.list_collections()]:
self._collection = self._chroma_client.get_collection(collection_name, embedding_function=embedding_func)

if distance_function != self._collection.metadata["hnsw:space"]:
logger.warning("Collection already exists. The `distance_function` parameter will be ignored.")
if metadata != self._collection.metadata:
logger.warning(
"Collection already exists. The `distance_function` and `metadata` parameters will be ignored."
)
else:
self._collection = self._chroma_client.create_collection(
name=collection_name,
Expand Down
40 changes: 39 additions & 1 deletion integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,48 @@ def test_distance_metric_reinitialization(self, caplog):
with caplog.at_level(logging.WARNING):
new_store = ChromaDocumentStore("test_4", distance_function="ip")

assert "Collection already exists. The `distance_function` parameter will be ignored." in caplog.text
assert (
"Collection already exists. The `distance_function` and `metadata` parameters will be ignored."
in caplog.text
)
assert store._collection.metadata["hnsw:space"] == "cosine"
assert new_store._collection.metadata["hnsw:space"] == "cosine"

@pytest.mark.integration
def test_metadata_initialization(self, caplog):
    """
    Check how `metadata` is handled when creating and when re-opening a collection.

    The first store creates collection "test_5" with a `"hnsw:space"` entry in `metadata`
    that differs from `distance_function`; the stored collection metadata is expected to
    carry the `metadata` values. Opening "test_5" again with a different `"hnsw:space"` is
    expected to log a warning and leave the stored metadata untouched.
    """
    # HNSW settings shared by both constructor calls; only "hnsw:space" differs.
    hnsw_tuning = {
        "hnsw:search_ef": 101,
        "hnsw:construction_ef": 102,
        "hnsw:M": 103,
    }
    expected_warning = (
        "Collection already exists. The `distance_function` and `metadata` parameters will be ignored."
    )

    store = ChromaDocumentStore(
        "test_5",
        distance_function="cosine",
        metadata={"hnsw:space": "ip", **hnsw_tuning},
    )
    assert store._collection.metadata["hnsw:space"] == "ip"
    for key, expected_value in hnsw_tuning.items():
        assert store._collection.metadata[key] == expected_value

    with caplog.at_level(logging.WARNING):
        new_store = ChromaDocumentStore(
            "test_5",
            metadata={"hnsw:space": "l2", **hnsw_tuning},
        )

    assert expected_warning in caplog.text
    # Both handles must still report the metadata the collection was created with.
    assert store._collection.metadata["hnsw:space"] == "ip"
    assert new_store._collection.metadata["hnsw:space"] == "ip"

@pytest.mark.skip(reason="Filter on dataframe contents is not supported.")
def test_filter_document_dataframe(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]):
    """Intentionally empty: this test is skipped for the reason given in the decorator."""
Expand Down

0 comments on commit ecaeedd

Please sign in to comment.