diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 1b6e95155..9493d3fcf 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -139,49 +139,72 @@ def __init__( :param grpc_secure: Whether to use a secure channel for the underlying gRPC API. """ + self._url = url + self._auth_client_secret = auth_client_secret + self._additional_headers = additional_headers + self._embedded_options = embedded_options + self._additional_config = additional_config + self._grpc_port = grpc_port + self._grpc_secure = grpc_secure + self._client = None + self._collection = None + # Store the connection settings dictionary + self._collection_settings = collection_settings or { + "class": "Default", + "invertedIndexConfig": {"indexNullState": True}, + "properties": DOCUMENT_COLLECTION_PROPERTIES, + } + self._clean_connection_settings() + + def _clean_connection_settings(self): + # Set the class if not set + _class_name = self._collection_settings.get("class", "Default") + _class_name = _class_name[0].upper() + _class_name[1:] + self._collection_settings["class"] = _class_name + # Set the properties if they're not set + self._collection_settings["properties"] = self._collection_settings.get( + "properties", DOCUMENT_COLLECTION_PROPERTIES + ) + + @property + def client(self): + if self._client: + return self._client + # proxies, timeout_config, trust_env are part of additional_config now # startup_period has been removed self._client = weaviate.WeaviateClient( connection_params=( - weaviate.connect.base.ConnectionParams.from_url(url=url, grpc_port=grpc_port, grpc_secure=grpc_secure) - if url + weaviate.connect.base.ConnectionParams.from_url( + url=self._url, grpc_port=self._grpc_port, grpc_secure=self._grpc_secure + ) + if self._url else None ), - auth_client_secret=auth_client_secret.resolve_value() if auth_client_secret else None, - additional_config=additional_config, - additional_headers=additional_headers, - embedded_options=embedded_options, + auth_client_secret=self._auth_client_secret.resolve_value() if self._auth_client_secret else None, + additional_config=self._additional_config, + additional_headers=self._additional_headers, + embedded_options=self._embedded_options, skip_init_checks=False, ) + self._client.connect() # Test connection, it will raise an exception if it fails. self._client.collections._get_all(simple=True) + if not self._client.collections.exists(self._collection_settings["class"]): + self._client.collections.create_from_dict(self._collection_settings) - if collection_settings is None: - collection_settings = { - "class": "Default", - "invertedIndexConfig": {"indexNullState": True}, - "properties": DOCUMENT_COLLECTION_PROPERTIES, - } - else: - # Set the class if not set - _class_name = collection_settings.get("class", "Default") - _class_name = _class_name[0].upper() + _class_name[1:] - collection_settings["class"] = _class_name - # Set the properties if they're not set - collection_settings["properties"] = collection_settings.get("properties", DOCUMENT_COLLECTION_PROPERTIES) + return self._client - if not self._client.collections.exists(collection_settings["class"]): - self._client.collections.create_from_dict(collection_settings) + @property + def collection(self): + if self._collection: + return self._collection - self._url = url - self._collection_settings = collection_settings - self._auth_client_secret = auth_client_secret - self._additional_headers = additional_headers - self._embedded_options = embedded_options - self._additional_config = additional_config - self._collection = self._client.collections.get(collection_settings["class"]) + client = self.client + self._collection = client.collections.get(self._collection_settings["class"]) + return self._collection def to_dict(self) -> Dict[str, Any]: """ @@ -230,7 +253,7 @@ def count_documents(self) -> int: """ Returns the number of documents present in the DocumentStore. """ - total = self._collection.aggregate.over_all(total_count=True).total_count + total = self.collection.aggregate.over_all(total_count=True).total_count return total if total else 0 def _to_data_object(self, document: Document) -> Dict[str, Any]: @@ -302,16 +325,16 @@ def _to_document(self, data: DataObject[Dict[str, Any], None]) -> Document: return Document.from_dict(document_data) def _query(self) -> List[Dict[str, Any]]: - properties = [p.name for p in self._collection.config.get().properties] + properties = [p.name for p in self.collection.config.get().properties] try: - result = self._collection.iterator(include_vector=True, return_properties=properties) + result = self.collection.iterator(include_vector=True, return_properties=properties) except weaviate.exceptions.WeaviateQueryError as e: msg = f"Failed to query documents in Weaviate. Error: {e.message}" raise DocumentStoreError(msg) from e return result def _query_with_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]: - properties = [p.name for p in self._collection.config.get().properties] + properties = [p.name for p in self.collection.config.get().properties] # When querying with filters we need to paginate using limit and offset as using # a cursor with after is not possible. See the official docs: # https://weaviate.io/developers/weaviate/api/graphql/additional-operators#cursor-with-after @@ -327,7 +350,7 @@ def _query_with_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]: # Keep querying until we get all documents matching the filters while partial_result is None or len(partial_result.objects) == DEFAULT_QUERY_LIMIT: try: - partial_result = self._collection.query.fetch_objects( + partial_result = self.collection.query.fetch_objects( filters=convert_filters(filters), include_vector=True, limit=DEFAULT_QUERY_LIMIT, @@ -365,7 +388,7 @@ def _batch_write(self, documents: List[Document]) -> int: Raises in case of errors. """ - with self._client.batch.dynamic() as batch: + with self.client.batch.dynamic() as batch: for doc in documents: if not isinstance(doc, Document): msg = f"Expected a Document, got '{type(doc)}' instead." @@ -373,11 +396,11 @@ def _batch_write(self, documents: List[Document]) -> int: batch.add_object( properties=self._to_data_object(doc), - collection=self._collection.name, + collection=self.collection.name, uuid=generate_uuid5(doc.id), vector=doc.embedding, ) - if failed_objects := self._client.batch.failed_objects: + if failed_objects := self.client.batch.failed_objects: # We fallback to use the UUID if the _original_id is not present, this is just to be mapped_objects = {} for obj in failed_objects: @@ -413,12 +436,12 @@ def _write(self, documents: List[Document], policy: DuplicatePolicy) -> int: msg = f"Expected a Document, got '{type(doc)}' instead." raise ValueError(msg) - if policy == DuplicatePolicy.SKIP and self._collection.data.exists(uuid=generate_uuid5(doc.id)): + if policy == DuplicatePolicy.SKIP and self.collection.data.exists(uuid=generate_uuid5(doc.id)): # This Document already exists, we skip it continue try: - self._collection.data.insert( + self.collection.data.insert( uuid=generate_uuid5(doc.id), properties=self._to_data_object(doc), vector=doc.embedding, @@ -454,13 +477,13 @@ def delete_documents(self, document_ids: List[str]) -> None: :param document_ids: The object_ids to delete. """ weaviate_ids = [generate_uuid5(doc_id) for doc_id in document_ids] - self._collection.data.delete_many(where=weaviate.classes.query.Filter.by_id().contains_any(weaviate_ids)) + self.collection.data.delete_many(where=weaviate.classes.query.Filter.by_id().contains_any(weaviate_ids)) def _bm25_retrieval( self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None ) -> List[Document]: - properties = [p.name for p in self._collection.config.get().properties] - result = self._collection.query.bm25( + properties = [p.name for p in self.collection.config.get().properties] + result = self.collection.query.bm25( query=query, filters=convert_filters(filters) if filters else None, limit=top_k, @@ -484,8 +507,8 @@ def _embedding_retrieval( msg = "Can't use 'distance' and 'certainty' parameters together" raise ValueError(msg) - properties = [p.name for p in self._collection.config.get().properties] - result = self._collection.query.near_vector( + properties = [p.name for p in self.collection.config.get().properties] + result = self.collection.query.near_vector( near_vector=query_embedding, distance=distance, certainty=certainty, diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 59a6ed2e3..e68d4d63d 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -38,6 +38,12 @@ ) +@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.WeaviateClient") +def test_init_is_lazy(_mock_client): + _ = WeaviateDocumentStore() + _mock_client.assert_not_called() + + @pytest.mark.integration class TestWeaviateDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest): @pytest.fixture @@ -57,7 +63,7 @@ def document_store(self, request) -> WeaviateDocumentStore: collection_settings=collection_settings, ) yield store - store._client.collections.delete(collection_settings["class"]) + store.client.collections.delete(collection_settings["class"]) @pytest.fixture def filterable_docs(self) -> List[Document]: @@ -150,12 +156,12 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do assert received_meta.get(key) == expected_meta.get(key) @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.WeaviateClient") - def test_init(self, mock_weaviate_client_class, monkeypatch): + def test_connection(self, mock_weaviate_client_class, monkeypatch): mock_client = MagicMock() mock_client.collections.exists.return_value = False mock_weaviate_client_class.return_value = mock_client monkeypatch.setenv("WEAVIATE_API_KEY", "my_api_key") - WeaviateDocumentStore( + ds = WeaviateDocumentStore( collection_settings={"class": "My_collection"}, auth_client_secret=AuthApiKey(), additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, @@ -170,8 +176,11 @@ def test_init(self, mock_weaviate_client_class, monkeypatch): ), ) - # Verify client is created with correct parameters + # Trigger the actual database connection by accessing the `client` property so we + # can assert the setup was good + _ = ds.client + # Verify client is created with correct parameters mock_weaviate_client_class.assert_called_once_with( auth_client_secret=AuthApiKey().resolve_value(), connection_params=None,