From 7723ddee7c17356f1393dc15b48711c73ba0663a Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Mon, 24 Jun 2024 15:35:30 +0200 Subject: [PATCH] fix: fix connection to Weaviate Cloud Service (#624) * Fix connection to Weaviate Cloud Service * Handle connection to WCS and add tests * Add comment explaining why we use utility function --- .../weaviate/document_store.py | 43 ++++++++++++------- .../weaviate/tests/test_document_store.py | 25 ++++++++++- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 5bacac010..a332d0bf9 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -172,22 +172,33 @@ def client(self): if self._client: return self._client - # proxies, timeout_config, trust_env are part of additional_config now - # startup_period has been removed - self._client = weaviate.WeaviateClient( - connection_params=( - weaviate.connect.base.ConnectionParams.from_url( - url=self._url, grpc_port=self._grpc_port, grpc_secure=self._grpc_secure - ) - if self._url - else None - ), - auth_client_secret=self._auth_client_secret.resolve_value() if self._auth_client_secret else None, - additional_config=self._additional_config, - additional_headers=self._additional_headers, - embedded_options=self._embedded_options, - skip_init_checks=False, - ) + if self._url and self._url.startswith("http") and self._url.endswith(".weaviate.network"): + # We use this utility function instead of using WeaviateClient directly like in other cases + # otherwise we'd have to parse the URL to get some information about the connection. + # This utility function does all that for us. + self._client = weaviate.connect_to_wcs( + self._url, + auth_credentials=self._auth_client_secret.resolve_value() if self._auth_client_secret else None, + headers=self._additional_headers, + additional_config=self._additional_config, + ) + else: + # proxies, timeout_config, trust_env are part of additional_config now + # startup_period has been removed + self._client = weaviate.WeaviateClient( + connection_params=( + weaviate.connect.base.ConnectionParams.from_url( + url=self._url, grpc_port=self._grpc_port, grpc_secure=self._grpc_secure + ) + if self._url + else None + ), + auth_client_secret=self._auth_client_secret.resolve_value() if self._auth_client_secret else None, + additional_config=self._additional_config, + additional_headers=self._additional_headers, + embedded_options=self._embedded_options, + skip_init_checks=False, + ) self._client.connect() diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index a4d275a22..31ff6e7b3 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -1,4 +1,5 @@ import base64 +import os import random from typing import List from unittest.mock import MagicMock, patch @@ -16,6 +17,7 @@ FilterDocumentsTest, WriteDocumentsTest, ) +from haystack.utils.auth import Secret from haystack_integrations.document_stores.weaviate.auth import AuthApiKey from haystack_integrations.document_stores.weaviate.document_store import ( DOCUMENT_COLLECTION_PROPERTIES, @@ -26,8 +28,6 @@ from numpy import float32 as np_float32 from pandas import DataFrame from weaviate.collections.classes.data import DataObject - -# from weaviate.auth import AuthApiKey as WeaviateAuthApiKey from weaviate.config import AdditionalConfig, ConnectionConfig, Proxies, Timeout from weaviate.embedded import ( DEFAULT_BINARY_PATH, @@ -697,3 +697,24 @@ def test_schema_class_name_conversion_preserves_pascal_case(self): collection_settings=collection_settings, ) assert doc_score._collection_settings["class"] == "Lower_case_name" + + @pytest.mark.skipif( + not os.environ.get("WEAVIATE_API_KEY", None) and not os.environ.get("WEAVIATE_CLOUD_CLUSTER_URL", None), + reason="Both WEAVIATE_API_KEY and WEAVIATE_CLOUD_CLUSTER_URL are not set. Skipping test.", + ) + def test_connect_to_weaviate_cloud(self): + document_store = WeaviateDocumentStore( + url=os.environ.get("WEAVIATE_CLOUD_CLUSTER_URL"), + auth_client_secret=AuthApiKey(api_key=Secret.from_env_var("WEAVIATE_API_KEY")), + ) + assert document_store.client + + def test_connect_to_local(self): + document_store = WeaviateDocumentStore( + url="http://localhost:8080", + ) + assert document_store.client + + def test_connect_to_embedded(self): + document_store = WeaviateDocumentStore(embedded_options=EmbeddedOptions()) + assert document_store.client