Skip to content

Commit

Permalink
Handle connection to WCS and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Jun 24, 2024
1 parent ecd3b61 commit 3186fae
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,10 @@ def client(self):
if self._client:
return self._client

# This is a quick ugly fix to make sure that users can use the DocumentStore
# with Weaviate Cloud Services with no issues
if self._url and self._url.startswith("http") and self._url.endswith(".weaviate.network"):
self._client = weaviate.connect_to_wcs(
self._url,
auth_credentials=self._auth_client_secret.resolve_value(),
auth_credentials=self._auth_client_secret.resolve_value() if self._auth_client_secret else None,
headers=self._additional_headers,
additional_config=self._additional_config,
)
Expand Down
25 changes: 23 additions & 2 deletions integrations/weaviate/tests/test_document_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import os
import random
from typing import List
from unittest.mock import MagicMock, patch
Expand All @@ -16,6 +17,7 @@
FilterDocumentsTest,
WriteDocumentsTest,
)
from haystack.utils.auth import Secret
from haystack_integrations.document_stores.weaviate.auth import AuthApiKey
from haystack_integrations.document_stores.weaviate.document_store import (
DOCUMENT_COLLECTION_PROPERTIES,
Expand All @@ -26,8 +28,6 @@
from numpy import float32 as np_float32
from pandas import DataFrame
from weaviate.collections.classes.data import DataObject

# from weaviate.auth import AuthApiKey as WeaviateAuthApiKey
from weaviate.config import AdditionalConfig, ConnectionConfig, Proxies, Timeout
from weaviate.embedded import (
DEFAULT_BINARY_PATH,
Expand Down Expand Up @@ -697,3 +697,24 @@ def test_schema_class_name_conversion_preserves_pascal_case(self):
collection_settings=collection_settings,
)
assert doc_score._collection_settings["class"] == "Lower_case_name"

@pytest.mark.skipif(
not os.environ.get("WEAVIATE_API_KEY", None) and not os.environ.get("WEAVIATE_CLOUD_CLUSTER_URL", None),
reason="Both WEAVIATE_API_KEY and WEAVIATE_CLOUD_CLUSTER_URL are not set. Skipping test.",
)
def test_connect_to_weaviate_cloud(self):
document_store = WeaviateDocumentStore(
url=os.environ.get("WEAVIATE_CLOUD_CLUSTER_URL"),
auth_client_secret=AuthApiKey(api_key=Secret.from_env_var("WEAVIATE_API_KEY")),
)
assert document_store.client

def test_connect_to_local(self):
document_store = WeaviateDocumentStore(
url="http://localhost:8080",
)
assert document_store.client

def test_connect_to_embedded(self):
document_store = WeaviateDocumentStore(embedded_options=EmbeddedOptions())
assert document_store.client

0 comments on commit 3186fae

Please sign in to comment.