From 979a8127a2805f03bdab8c92cdd829a95cb5079b Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Thu, 29 Feb 2024 17:51:47 +0100 Subject: [PATCH] fix: serialize the path to the local db (#506) * serialize the path to the local db * fix tests * fix tests --- integrations/chroma/.gitignore | 4 +-- integrations/chroma/pyproject.toml | 3 ++- .../components/retrievers/chroma/retriever.py | 26 ++++++++++--------- .../document_stores/chroma/document_store.py | 16 +++++++----- .../chroma/tests/test_document_store.py | 20 +++++++++++--- integrations/chroma/tests/test_retriever.py | 21 ++++++++++----- 6 files changed, 59 insertions(+), 31 deletions(-) diff --git a/integrations/chroma/.gitignore b/integrations/chroma/.gitignore index d1c340c1f..a3d827e06 100644 --- a/integrations/chroma/.gitignore +++ b/integrations/chroma/.gitignore @@ -58,8 +58,8 @@ cover/ # Django stuff: *.log local_settings.py -db.sqlite3 -db.sqlite3-journal +*.sqlite3 +*.sqlite3-journal # Flask stuff: instance/ diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index 2653c491f..f8265a1e7 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -189,6 +189,7 @@ module = [ "chromadb.*", "haystack.*", "haystack_integrations.*", - "pytest.*" + "pytest.*", + "numpy.*" ] ignore_missing_imports = true diff --git a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py index e19d4acbe..10f97f01f 100644 --- a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py +++ b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py @@ -73,18 +73,6 @@ def run( return {"documents": self.document_store.search([query], top_k)[0]} - def to_dict(self) -> Dict[str, Any]: - """ - Serializes the component to a dictionary. - - :returns: - Dictionary with serialized data. - """ - d = default_to_dict(self, filters=self.filters, top_k=self.top_k) - d["init_parameters"]["document_store"] = self.document_store.to_dict() - - return d - @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever": """ @@ -99,6 +87,20 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever": data["init_parameters"]["document_store"] = document_store return default_from_dict(cls, data) + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self.filters, + top_k=self.top_k, + document_store=self.document_store.to_dict(), + ) + @component class ChromaEmbeddingRetriever(ChromaQueryTextRetriever): diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 8bb90aa2e..1dbeddbd3 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -8,6 +8,7 @@ import chromadb import numpy as np from chromadb.api.types import GetResult, QueryResult, validate_where, validate_where_document +from haystack import default_from_dict, default_to_dict from haystack.dataclasses import Document from haystack.document_stores.types import DuplicatePolicy @@ -50,6 +51,7 @@ def __init__( self._collection_name = collection_name self._embedding_function = embedding_function self._embedding_function_params = embedding_function_params + self._persist_path = persist_path # Create the client instance if persist_path is None: self._chroma_client = chromadb.Client() @@ -252,20 +254,22 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChromaDocumentStore": :returns: Deserialized component. """ - return ChromaDocumentStore(**data) + return default_from_dict(cls, data) def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. :returns: - Dictionary with serialized data. + Dictionary with serialized data. """ - return { - "collection_name": self._collection_name, - "embedding_function": self._embedding_function, + return default_to_dict( + self, + collection_name=self._collection_name, + embedding_function=self._embedding_function, + persist_path=self._persist_path, **self._embedding_function_params, - } + ) @staticmethod def _normalize_filters(filters: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any], Dict[str, Any]]: diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 506920ac2..8d61e63ed 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -97,16 +97,28 @@ def test_to_json(self, request): ) ds_dict = ds.to_dict() assert ds_dict == { - "collection_name": request.node.name, - "embedding_function": "HuggingFaceEmbeddingFunction", - "api_key": "1234567890", + "type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore", + "init_parameters": { + "collection_name": "test_to_json", + "embedding_function": "HuggingFaceEmbeddingFunction", + "persist_path": None, + "api_key": "1234567890", + }, } @pytest.mark.integration def test_from_json(self): collection_name = "test_collection" function_name = "HuggingFaceEmbeddingFunction" - ds_dict = {"collection_name": collection_name, "embedding_function": function_name, "api_key": "1234567890"} + ds_dict = { + "type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore", + "init_parameters": { + "collection_name": "test_collection", + "embedding_function": "HuggingFaceEmbeddingFunction", + "persist_path": None, + "api_key": "1234567890", + }, + } ds = ChromaDocumentStore.from_dict(ds_dict) assert ds._collection_name == collection_name diff --git a/integrations/chroma/tests/test_retriever.py b/integrations/chroma/tests/test_retriever.py index 88969d725..b430e5fda 100644 --- a/integrations/chroma/tests/test_retriever.py +++ b/integrations/chroma/tests/test_retriever.py @@ -18,9 +18,13 @@ def test_retriever_to_json(request): "filters": {"foo": "bar"}, "top_k": 99, "document_store": { - "collection_name": request.node.name, - "embedding_function": "HuggingFaceEmbeddingFunction", - "api_key": "1234567890", + "type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore", + "init_parameters": { + "collection_name": "test_retriever_to_json", + "embedding_function": "HuggingFaceEmbeddingFunction", + "persist_path": None, + "api_key": "1234567890", + }, }, }, } @@ -34,9 +38,13 @@ def test_retriever_from_json(request): "filters": {"bar": "baz"}, "top_k": 42, "document_store": { - "collection_name": request.node.name, - "embedding_function": "HuggingFaceEmbeddingFunction", - "api_key": "1234567890", + "type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore", + "init_parameters": { + "collection_name": "test_retriever_from_json", + "embedding_function": "HuggingFaceEmbeddingFunction", + "persist_path": ".", + "api_key": "1234567890", + }, }, }, } @@ -44,5 +52,6 @@ def test_retriever_from_json(request): assert retriever.document_store._collection_name == request.node.name assert retriever.document_store._embedding_function == "HuggingFaceEmbeddingFunction" assert retriever.document_store._embedding_function_params == {"api_key": "1234567890"} + assert retriever.document_store._persist_path == "." assert retriever.filters == {"bar": "baz"} assert retriever.top_k == 42