Skip to content

Commit

Permalink
Accept more collection settings when initializing WeaviateDocumentStore
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Jan 24, 2024
1 parent 9c5a63e commit db919ad
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 18 deletions.
2 changes: 1 addition & 1 deletion integrations/weaviate/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ services:
PERSISTENCE_DATA_PATH: '/var/lib/weaviate'
DEFAULT_VECTORIZER_MODULE: 'none'
ENABLE_MODULES: ''
CLUSTER_HOSTNAME: 'node1'
CLUSTER_HOSTNAME: 'node1'
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses.document import Document
from haystack.document_stores.protocol import DuplicatePolicy
from haystack.document_stores.types.policy import DuplicatePolicy

import weaviate
from weaviate.auth import AuthCredentials
Expand All @@ -25,6 +25,20 @@
"weaviate.auth.AuthApiKey": weaviate.auth.AuthApiKey,
}

# This is the default collection properties for Weaviate.
# It's a list of properties that will be created on the collection.
# These are extremely similar to the Document dataclass, but with a few differences:
# - `id` is renamed to `_original_id` as the `id` field is reserved by Weaviate.
# - `blob` is split into `blob_data` and `blob_mime_type` as it's more efficient to store them separately.
DOCUMENT_COLLECTION_PROPERTIES = [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
{"name": "score", "dataType": ["number"]},
]


class WeaviateDocumentStore:
"""
Expand All @@ -35,7 +49,7 @@ def __init__(
self,
*,
url: Optional[str] = None,
collection_name: str = "default",
collection_settings: Optional[Dict[str, Any]] = None,
auth_client_secret: Optional[AuthCredentials] = None,
timeout_config: TimeoutType = (10, 60),
proxies: Optional[Union[Dict, str]] = None,
Expand All @@ -49,6 +63,16 @@ def __init__(
Create a new instance of WeaviateDocumentStore and connects to the Weaviate instance.
:param url: The URL to the weaviate instance, defaults to None.
:param collection_settings: The collection settings to use, defaults to None.
If None it will use a collection named `default` with the following properties:
- _original_id: text
- content: text
- dataframe: text
- blob_data: blob
- blob_mime_type: text
- score: number
See the official `Weaviate documentation<https://weaviate.io/developers/weaviate/manage-data/collections>`_
for more information on collections.
:param auth_client_secret: Authentication credentials, defaults to None.
Can be one of the following types depending on the authentication mode:
- `weaviate.auth.AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens
Expand Down Expand Up @@ -80,8 +104,6 @@ def __init__(
:param embedded_options: If set create an embedded Weaviate cluster inside the client, defaults to None.
For a full list of options see `weaviate.embedded.EmbeddedOptions`.
:param additional_config: Additional and advanced configuration options for weaviate, defaults to None.
:param collection_name: The name of the collection to use, defaults to "default".
If the collection does not exist it will be created.
"""
self._client = weaviate.Client(
url=url,
Expand All @@ -98,11 +120,22 @@ def __init__(
# Test connection, it will raise an exception if it fails.
self._client.schema.get()

if not self._client.schema.exists(collection_name):
self._client.schema.create_class({"class": collection_name})
if collection_settings is None:
collection_settings = {
"class": "Default",
"properties": DOCUMENT_COLLECTION_PROPERTIES,
}
else:
# Set the class if not set
collection_settings["class"] = collection_settings.get("class", "default").capitalize()
# Set the properties if they're not set
collection_settings["properties"] = collection_settings.get("properties", DOCUMENT_COLLECTION_PROPERTIES)

if not self._client.schema.exists(collection_settings["class"]):
self._client.schema.create_class(collection_settings)

self._url = url
self._collection_name = collection_name
self._collection_settings = collection_settings
self._auth_client_secret = auth_client_secret
self._timeout_config = timeout_config
self._proxies = proxies
Expand All @@ -124,7 +157,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
url=self._url,
collection_name=self._collection_name,
collection_settings=self._collection_settings,
auth_client_secret=auth_client_secret,
timeout_config=self._timeout_config,
proxies=self._proxies,
Expand Down Expand Up @@ -161,7 +194,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
return []

def write_documents(
self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE # noqa: ARG002
self,
documents: List[Document],
policy: DuplicatePolicy = DuplicatePolicy.NONE, # noqa: ARG002
) -> int:
return 0

Expand Down
52 changes: 44 additions & 8 deletions integrations/weaviate/tests/test_document_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from unittest.mock import MagicMock, patch

from haystack_integrations.document_stores.weaviate.document_store import WeaviateDocumentStore
import pytest
from haystack_integrations.document_stores.weaviate.document_store import (
DOCUMENT_COLLECTION_PROPERTIES,
WeaviateDocumentStore,
)
from weaviate.auth import AuthApiKey
from weaviate.config import Config
from weaviate.embedded import (
Expand All @@ -13,6 +17,17 @@


class TestWeaviateDocumentStore:
@pytest.fixture
def document_store(self, request) -> WeaviateDocumentStore:
# Use a different index for each test so we can run them in parallel
collection_settings = {"class": f"{request.node.name}"}
store = WeaviateDocumentStore(
url="http://localhost:8080",
collection_settings=collection_settings,
)
yield store
store._client.schema.delete_class(collection_settings["class"])

@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.Client")
def test_init(self, mock_weaviate_client_class):
mock_client = MagicMock()
Expand All @@ -21,7 +36,7 @@ def test_init(self, mock_weaviate_client_class):

WeaviateDocumentStore(
url="http://localhost:8080",
collection_name="my_collection",
collection_settings={"class": "My_collection"},
auth_client_secret=AuthApiKey("my_api_key"),
proxies={"http": "http://proxy:1234"},
additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
Expand Down Expand Up @@ -54,14 +69,15 @@ def test_init(self, mock_weaviate_client_class):

# Verify collection is created
mock_client.schema.get.assert_called_once()
mock_client.schema.exists.assert_called_once_with("my_collection")
mock_client.schema.create_class.assert_called_once_with({"class": "my_collection"})
mock_client.schema.exists.assert_called_once_with("My_collection")
mock_client.schema.create_class.assert_called_once_with(
{"class": "My_collection", "properties": DOCUMENT_COLLECTION_PROPERTIES}
)

@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
def test_to_dict(self, _mock_weaviate):
document_store = WeaviateDocumentStore(
url="http://localhost:8080",
collection_name="my_collection",
auth_client_secret=AuthApiKey("my_api_key"),
proxies={"http": "http://proxy:1234"},
additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
Expand All @@ -77,7 +93,17 @@ def test_to_dict(self, _mock_weaviate):
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": "http://localhost:8080",
"collection_name": "my_collection",
"collection_settings": {
"class": "Default",
"properties": [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
{"name": "score", "dataType": ["number"]},
],
},
"auth_client_secret": {
"type": "weaviate.auth.AuthApiKey",
"init_parameters": {"api_key": "my_api_key"},
Expand Down Expand Up @@ -113,7 +139,7 @@ def test_from_dict(self, _mock_weaviate):
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": "http://localhost:8080",
"collection_name": "my_collection",
"collection_settings": None,
"auth_client_secret": {
"type": "weaviate.auth.AuthApiKey",
"init_parameters": {"api_key": "my_api_key"},
Expand Down Expand Up @@ -144,7 +170,17 @@ def test_from_dict(self, _mock_weaviate):
)

assert document_store._url == "http://localhost:8080"
assert document_store._collection_name == "my_collection"
assert document_store._collection_settings == {
"class": "Default",
"properties": [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
{"name": "score", "dataType": ["number"]},
],
}
assert document_store._auth_client_secret == AuthApiKey("my_api_key")
assert document_store._timeout_config == (10, 60)
assert document_store._proxies == {"http": "http://proxy:1234"}
Expand Down

0 comments on commit db919ad

Please sign in to comment.