Skip to content

Commit

Permalink
Handle connection to WCS and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Jun 24, 2024
1 parent 75bb792 commit f48802b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,22 +172,30 @@ def client(self):
if self._client:
return self._client

# proxies, timeout_config, trust_env are part of additional_config now
# startup_period has been removed
self._client = weaviate.WeaviateClient(
connection_params=(
weaviate.connect.base.ConnectionParams.from_url(
url=self._url, grpc_port=self._grpc_port, grpc_secure=self._grpc_secure
)
if self._url
else None
),
auth_client_secret=self._auth_client_secret.resolve_value() if self._auth_client_secret else None,
additional_config=self._additional_config,
additional_headers=self._additional_headers,
embedded_options=self._embedded_options,
skip_init_checks=False,
)
if self._url and self._url.startswith("http") and self._url.endswith(".weaviate.network"):
self._client = weaviate.connect_to_wcs(
self._url,
auth_credentials=self._auth_client_secret.resolve_value() if self._auth_client_secret else None,
headers=self._additional_headers,
additional_config=self._additional_config,
)
else:
# proxies, timeout_config, trust_env are part of additional_config now
# startup_period has been removed
self._client = weaviate.WeaviateClient(
connection_params=(
weaviate.connect.base.ConnectionParams.from_url(
url=self._url, grpc_port=self._grpc_port, grpc_secure=self._grpc_secure
)
if self._url
else None
),
auth_client_secret=self._auth_client_secret.resolve_value() if self._auth_client_secret else None,
additional_config=self._additional_config,
additional_headers=self._additional_headers,
embedded_options=self._embedded_options,
skip_init_checks=False,
)

self._client.connect()

Expand Down
25 changes: 23 additions & 2 deletions integrations/weaviate/tests/test_document_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import os
import random
from typing import List
from unittest.mock import MagicMock, patch
Expand All @@ -16,6 +17,7 @@
FilterDocumentsTest,
WriteDocumentsTest,
)
from haystack.utils.auth import Secret
from haystack_integrations.document_stores.weaviate.auth import AuthApiKey
from haystack_integrations.document_stores.weaviate.document_store import (
DOCUMENT_COLLECTION_PROPERTIES,
Expand All @@ -26,8 +28,6 @@
from numpy import float32 as np_float32
from pandas import DataFrame
from weaviate.collections.classes.data import DataObject

# from weaviate.auth import AuthApiKey as WeaviateAuthApiKey
from weaviate.config import AdditionalConfig, ConnectionConfig, Proxies, Timeout
from weaviate.embedded import (
DEFAULT_BINARY_PATH,
Expand Down Expand Up @@ -697,3 +697,24 @@ def test_schema_class_name_conversion_preserves_pascal_case(self):
collection_settings=collection_settings,
)
assert doc_score._collection_settings["class"] == "Lower_case_name"

@pytest.mark.skipif(
not os.environ.get("WEAVIATE_API_KEY", None) and not os.environ.get("WEAVIATE_CLOUD_CLUSTER_URL", None),
reason="Both WEAVIATE_API_KEY and WEAVIATE_CLOUD_CLUSTER_URL are not set. Skipping test.",
)
def test_connect_to_weaviate_cloud(self):
document_store = WeaviateDocumentStore(
url=os.environ.get("WEAVIATE_CLOUD_CLUSTER_URL"),
auth_client_secret=AuthApiKey(api_key=Secret.from_env_var("WEAVIATE_API_KEY")),
)
assert document_store.client

def test_connect_to_local(self):
document_store = WeaviateDocumentStore(
url="http://localhost:8080",
)
assert document_store.client

def test_connect_to_embedded(self):
document_store = WeaviateDocumentStore(embedded_options=EmbeddedOptions())
assert document_store.client

0 comments on commit f48802b

Please sign in to comment.