From 7127be63c9aa152b5139c1e826f7d345f912b178 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Mon, 1 Jul 2024 01:50:01 +0200 Subject: [PATCH] feat: added distance_function property to ChromadocumentStore (#817) * Added the distance metric property --------- Co-authored-by: Amna Mubashar Co-authored-by: Stefano Fiorucci --- .../document_stores/chroma/document_store.py | 44 ++++++++++++++++--- .../chroma/tests/test_document_store.py | 26 ++++++++++- integrations/chroma/tests/test_retriever.py | 2 + 3 files changed, 65 insertions(+), 7 deletions(-) diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 02acbe8dc..d39158db4 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple import chromadb import numpy as np @@ -18,6 +18,9 @@ logger = logging.getLogger(__name__) +VALID_DISTANCE_FUNCTIONS = "l2", "cosine", "ip" + + class ChromaDocumentStore: """ A document store using [Chroma](https://docs.trychroma.com/) as the backend. @@ -31,6 +34,7 @@ def __init__( collection_name: str = "documents", embedding_function: str = "default", persist_path: Optional[str] = None, + distance_function: Literal["l2", "cosine", "ip"] = "l2", **embedding_function_params, ): """ @@ -45,22 +49,51 @@ def __init__( :param collection_name: the name of the collection to use in the database. :param embedding_function: the name of the embedding function to use to embed the query :param persist_path: where to store the database. If None, the database will be `in-memory`. + :param distance_function: The distance metric for the embedding space. + - `"l2"` computes the Euclidean (straight-line) distance between vectors, + where smaller scores indicate more similarity. + - `"cosine"` computes the cosine similarity between vectors, + with higher scores indicating greater similarity. + - `"ip"` stands for inner product, where higher scores indicate greater similarity between vectors. + **Note**: `distance_function` can only be set during the creation of a collection. + To change the distance metric of an existing collection, consider cloning the collection. + :param embedding_function_params: additional parameters to pass to the embedding function. """ + + if distance_function not in VALID_DISTANCE_FUNCTIONS: + error_message = ( + f"Invalid distance_function: '{distance_function}' for the collection. " + f"Valid options are: {VALID_DISTANCE_FUNCTIONS}." + ) + raise ValueError(error_message) + # Store the params for marshalling self._collection_name = collection_name self._embedding_function = embedding_function self._embedding_function_params = embedding_function_params self._persist_path = persist_path + self._distance_function = distance_function # Create the client instance if persist_path is None: self._chroma_client = chromadb.Client() else: self._chroma_client = chromadb.PersistentClient(path=persist_path) - self._collection = self._chroma_client.get_or_create_collection( - name=collection_name, - embedding_function=get_embedding_function(embedding_function, **embedding_function_params), - ) + + embedding_func = get_embedding_function(embedding_function, **embedding_function_params) + metadata = {"hnsw:space": distance_function} + + if collection_name in [c.name for c in self._chroma_client.list_collections()]: + self._collection = self._chroma_client.get_collection(collection_name, embedding_function=embedding_func) + + if distance_function != self._collection.metadata["hnsw:space"]: + logger.warning("Collection already exists. The `distance_function` parameter will be ignored.") + else: + self._collection = self._chroma_client.create_collection( + name=collection_name, + metadata=metadata, + embedding_function=embedding_func, + ) def count_documents(self) -> int: """ @@ -290,6 +323,7 @@ def to_dict(self) -> Dict[str, Any]: collection_name=self._collection_name, embedding_function=self._embedding_function, persist_path=self._persist_path, + distance_function=self._distance_function, **self._embedding_function_params, ) diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 742f3305e..4e1181ae2 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present John Doe # # SPDX-License-Identifier: Apache-2.0 +import logging import operator import uuid from typing import List @@ -104,6 +105,7 @@ def test_to_json(self, request): "embedding_function": "HuggingFaceEmbeddingFunction", "persist_path": None, "api_key": "1234567890", + "distance_function": "l2", }, } @@ -118,6 +120,7 @@ def test_from_json(self): "embedding_function": "HuggingFaceEmbeddingFunction", "persist_path": None, "api_key": "1234567890", + "distance_function": "l2", }, } @@ -128,8 +131,27 @@ def test_from_json(self): @pytest.mark.integration def test_same_collection_name_reinitialization(self): - ChromaDocumentStore("test_name") - ChromaDocumentStore("test_name") + ChromaDocumentStore("test_1") + ChromaDocumentStore("test_1") + + @pytest.mark.integration + def test_distance_metric_initialization(self): + store = ChromaDocumentStore("test_2", distance_function="cosine") + assert store._collection.metadata["hnsw:space"] == "cosine" + + with pytest.raises(ValueError): + ChromaDocumentStore("test_3", distance_function="jaccard") + + @pytest.mark.integration + def test_distance_metric_reinitialization(self, caplog): + store = ChromaDocumentStore("test_4", distance_function="cosine") + + with caplog.at_level(logging.WARNING): + new_store = ChromaDocumentStore("test_4", distance_function="ip") + + assert "Collection already exists. The `distance_function` parameter will be ignored." in caplog.text + assert store._collection.metadata["hnsw:space"] == "cosine" + assert new_store._collection.metadata["hnsw:space"] == "cosine" @pytest.mark.skip(reason="Filter on dataframe contents is not supported.") def test_filter_document_dataframe(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): diff --git a/integrations/chroma/tests/test_retriever.py b/integrations/chroma/tests/test_retriever.py index b430e5fda..4ee320351 100644 --- a/integrations/chroma/tests/test_retriever.py +++ b/integrations/chroma/tests/test_retriever.py @@ -24,6 +24,7 @@ def test_retriever_to_json(request): "embedding_function": "HuggingFaceEmbeddingFunction", "persist_path": None, "api_key": "1234567890", + "distance_function": "l2", }, }, }, @@ -44,6 +45,7 @@ def test_retriever_from_json(request): "embedding_function": "HuggingFaceEmbeddingFunction", "persist_path": ".", "api_key": "1234567890", + "distance_function": "l2", }, }, },