From ecaeedd3ab192bd62ad4e21a53e7201e1efac0ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillaume=20Ch=C3=A9rel?= Date: Tue, 16 Jul 2024 11:44:57 +0200 Subject: [PATCH] feat: Add metadata parameter to ChromaDocumentStore. (#906) * feat: Add metadata parameter to ChromaDocumentStore. * Update integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py Co-authored-by: Stefano Fiorucci * Update integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py Co-authored-by: Stefano Fiorucci * Update integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py Co-authored-by: Stefano Fiorucci * test: update logging message in chroma document store tests * style: Fix formatting * test: Add test for logging messages when creating chroma collection with the same name. * test: Fix logging message. --------- Co-authored-by: Stefano Fiorucci --- .../document_stores/chroma/document_store.py | 15 +++++-- .../chroma/tests/test_document_store.py | 40 ++++++++++++++++++- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 49cfced2e..937d841f8 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -35,6 +35,7 @@ def __init__( embedding_function: str = "default", persist_path: Optional[str] = None, distance_function: Literal["l2", "cosine", "ip"] = "l2", + metadata: Optional[dict] = None, **embedding_function_params, ): """ @@ -57,6 +58,9 @@ def __init__( - `"ip"` stands for inner product, where higher scores indicate greater similarity between vectors. **Note**: `distance_function` can only be set during the creation of a collection. To change the distance metric of an existing collection, consider cloning the collection. + :param metadata: a dictionary of chromadb collection parameters passed directly to chromadb's client + method `create_collection`. If it contains the key `"hnsw:space"`, the value will take precedence over the + `distance_function` parameter above. :param embedding_function_params: additional parameters to pass to the embedding function. """ @@ -81,13 +85,18 @@ def __init__( self._chroma_client = chromadb.PersistentClient(path=persist_path) embedding_func = get_embedding_function(embedding_function, **embedding_function_params) - metadata = {"hnsw:space": distance_function} + + metadata = metadata or {} + if "hnsw:space" not in metadata: + metadata["hnsw:space"] = distance_function if collection_name in [c.name for c in self._chroma_client.list_collections()]: self._collection = self._chroma_client.get_collection(collection_name, embedding_function=embedding_func) - if distance_function != self._collection.metadata["hnsw:space"]: - logger.warning("Collection already exists. The `distance_function` parameter will be ignored.") + if metadata != self._collection.metadata: + logger.warning( + "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." + ) else: self._collection = self._chroma_client.create_collection( name=collection_name, diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 774096a15..223bfd704 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -164,10 +164,48 @@ def test_distance_metric_reinitialization(self, caplog): with caplog.at_level(logging.WARNING): new_store = ChromaDocumentStore("test_4", distance_function="ip") - assert "Collection already exists. The `distance_function` parameter will be ignored." in caplog.text + assert ( + "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." + in caplog.text + ) assert store._collection.metadata["hnsw:space"] == "cosine" assert new_store._collection.metadata["hnsw:space"] == "cosine" + @pytest.mark.integration + def test_metadata_initialization(self, caplog): + store = ChromaDocumentStore( + "test_5", + distance_function="cosine", + metadata={ + "hnsw:space": "ip", + "hnsw:search_ef": 101, + "hnsw:construction_ef": 102, + "hnsw:M": 103, + }, + ) + assert store._collection.metadata["hnsw:space"] == "ip" + assert store._collection.metadata["hnsw:search_ef"] == 101 + assert store._collection.metadata["hnsw:construction_ef"] == 102 + assert store._collection.metadata["hnsw:M"] == 103 + + with caplog.at_level(logging.WARNING): + new_store = ChromaDocumentStore( + "test_5", + metadata={ + "hnsw:space": "l2", + "hnsw:search_ef": 101, + "hnsw:construction_ef": 102, + "hnsw:M": 103, + }, + ) + + assert ( + "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." + in caplog.text + ) + assert store._collection.metadata["hnsw:space"] == "ip" + assert new_store._collection.metadata["hnsw:space"] == "ip" + @pytest.mark.skip(reason="Filter on dataframe contents is not supported.") def test_filter_document_dataframe(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass