Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support more collection settings when creating a new WeaviateDocumentStore #260

Merged
merged 3 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions integrations/weaviate/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Docker Compose definition for a single Weaviate node listening on
# port 8080 of the host.
version: '3.4'
services:
  weaviate:
    # Arguments passed to the Weaviate server: listen on all interfaces,
    # port 8080, plain HTTP.
    command:
      - --host
      - 0.0.0.0
      - --port
      - '8080'
      - --scheme
      - http
    image: semitechnologies/weaviate:1.23.2
    ports:
      # HTTP API; matches the `--port` argument above.
      - 8080:8080
      # NOTE(review): presumably the gRPC endpoint - confirm against the Weaviate docs.
      - 50051:50051
    restart: on-failure:0
    environment:
      QUERY_DEFAULTS_LIMIT: 25
      # No credentials are required to talk to this instance.
      AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
      PERSISTENCE_DATA_PATH: '/var/lib/weaviate'
      # No vectorizer module is configured and no modules are enabled.
      DEFAULT_VECTORIZER_MODULE: 'none'
      ENABLE_MODULES: ''
      CLUSTER_HOSTNAME: 'node1'
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses.document import Document
from haystack.document_stores.protocol import DuplicatePolicy
from haystack.document_stores.types.policy import DuplicatePolicy

import weaviate
from weaviate.auth import AuthCredentials
Expand All @@ -25,6 +25,20 @@
"weaviate.auth.AuthApiKey": weaviate.auth.AuthApiKey,
}

# Default collection properties for Weaviate.
# This is the list of properties that will be created on the collection.
# These are extremely similar to the Document dataclass, but with a few differences:
# - `id` is renamed to `_original_id` as the `id` field is reserved by Weaviate.
# - `blob` is split into `blob_data` and `blob_mime_type` as it's more efficient to store them separately.
DOCUMENT_COLLECTION_PROPERTIES = [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
silvanocerza marked this conversation as resolved.
Show resolved Hide resolved
{"name": "score", "dataType": ["number"]},
]

anakin87 marked this conversation as resolved.
Show resolved Hide resolved

class WeaviateDocumentStore:
"""
Expand All @@ -35,7 +49,7 @@ def __init__(
self,
*,
url: Optional[str] = None,
collection_name: str = "default",
collection_settings: Optional[Dict[str, Any]] = None,
auth_client_secret: Optional[AuthCredentials] = None,
timeout_config: TimeoutType = (10, 60),
proxies: Optional[Union[Dict, str]] = None,
Expand All @@ -49,6 +63,16 @@ def __init__(
Create a new instance of WeaviateDocumentStore and connects to the Weaviate instance.

:param url: The URL to the weaviate instance, defaults to None.
:param collection_settings: The collection settings to use, defaults to None.
If None it will use a collection named `default` with the following properties:
- _original_id: text
- content: text
- dataframe: text
- blob_data: blob
- blob_mime_type: text
- score: number
See the official `Weaviate documentation <https://weaviate.io/developers/weaviate/manage-data/collections>`_
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just curious about this format for links...
Is it Sphinx? Is it supported in ReadMe?
Should it be Link text <link URL>?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be handled by pydoc when converting the docstrings to Markdown.

In any case, we need to standardise the docstrings sooner rather than later.

for more information on collections.
:param auth_client_secret: Authentication credentials, defaults to None.
Can be one of the following types depending on the authentication mode:
- `weaviate.auth.AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens
Expand Down Expand Up @@ -80,8 +104,6 @@ def __init__(
:param embedded_options: If set create an embedded Weaviate cluster inside the client, defaults to None.
For a full list of options see `weaviate.embedded.EmbeddedOptions`.
:param additional_config: Additional and advanced configuration options for weaviate, defaults to None.
:param collection_name: The name of the collection to use, defaults to "default".
If the collection does not exist it will be created.
"""
self._client = weaviate.Client(
url=url,
Expand All @@ -98,11 +120,22 @@ def __init__(
# Test connection, it will raise an exception if it fails.
self._client.schema.get()

if not self._client.schema.exists(collection_name):
self._client.schema.create_class({"class": collection_name})
if collection_settings is None:
collection_settings = {
"class": "Default",
"properties": DOCUMENT_COLLECTION_PROPERTIES,
}
else:
# Set the class if not set
collection_settings["class"] = collection_settings.get("class", "default").capitalize()
# Set the properties if they're not set
collection_settings["properties"] = collection_settings.get("properties", DOCUMENT_COLLECTION_PROPERTIES)

if not self._client.schema.exists(collection_settings["class"]):
self._client.schema.create_class(collection_settings)

self._url = url
self._collection_name = collection_name
self._collection_settings = collection_settings
self._auth_client_secret = auth_client_secret
self._timeout_config = timeout_config
self._proxies = proxies
Expand All @@ -124,7 +157,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
url=self._url,
collection_name=self._collection_name,
collection_settings=self._collection_settings,
auth_client_secret=auth_client_secret,
timeout_config=self._timeout_config,
proxies=self._proxies,
Expand Down Expand Up @@ -161,7 +194,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
return []

def write_documents(
self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE # noqa: ARG002
self,
documents: List[Document], # noqa: ARG002
policy: DuplicatePolicy = DuplicatePolicy.NONE, # noqa: ARG002
) -> int:
return 0

Expand Down
52 changes: 44 additions & 8 deletions integrations/weaviate/tests/test_document_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from unittest.mock import MagicMock, patch

from haystack_integrations.document_stores.weaviate.document_store import WeaviateDocumentStore
import pytest
from haystack_integrations.document_stores.weaviate.document_store import (
DOCUMENT_COLLECTION_PROPERTIES,
WeaviateDocumentStore,
)
from weaviate.auth import AuthApiKey
from weaviate.config import Config
from weaviate.embedded import (
Expand All @@ -13,6 +17,17 @@


class TestWeaviateDocumentStore:
@pytest.fixture
def document_store(self, request) -> WeaviateDocumentStore:
# Use a different index for each test so we can run them in parallel
collection_settings = {"class": f"{request.node.name}"}
store = WeaviateDocumentStore(
url="http://localhost:8080",
collection_settings=collection_settings,
)
yield store
store._client.schema.delete_class(collection_settings["class"])

@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.Client")
def test_init(self, mock_weaviate_client_class):
mock_client = MagicMock()
Expand All @@ -21,7 +36,7 @@ def test_init(self, mock_weaviate_client_class):

WeaviateDocumentStore(
url="http://localhost:8080",
collection_name="my_collection",
collection_settings={"class": "My_collection"},
auth_client_secret=AuthApiKey("my_api_key"),
proxies={"http": "http://proxy:1234"},
additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
Expand Down Expand Up @@ -54,14 +69,15 @@ def test_init(self, mock_weaviate_client_class):

# Verify collection is created
mock_client.schema.get.assert_called_once()
mock_client.schema.exists.assert_called_once_with("my_collection")
mock_client.schema.create_class.assert_called_once_with({"class": "my_collection"})
mock_client.schema.exists.assert_called_once_with("My_collection")
mock_client.schema.create_class.assert_called_once_with(
{"class": "My_collection", "properties": DOCUMENT_COLLECTION_PROPERTIES}
)

@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
def test_to_dict(self, _mock_weaviate):
document_store = WeaviateDocumentStore(
url="http://localhost:8080",
collection_name="my_collection",
auth_client_secret=AuthApiKey("my_api_key"),
proxies={"http": "http://proxy:1234"},
additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
Expand All @@ -77,7 +93,17 @@ def test_to_dict(self, _mock_weaviate):
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": "http://localhost:8080",
"collection_name": "my_collection",
"collection_settings": {
"class": "Default",
"properties": [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
{"name": "score", "dataType": ["number"]},
],
},
"auth_client_secret": {
"type": "weaviate.auth.AuthApiKey",
"init_parameters": {"api_key": "my_api_key"},
Expand Down Expand Up @@ -113,7 +139,7 @@ def test_from_dict(self, _mock_weaviate):
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": "http://localhost:8080",
"collection_name": "my_collection",
"collection_settings": None,
"auth_client_secret": {
"type": "weaviate.auth.AuthApiKey",
"init_parameters": {"api_key": "my_api_key"},
Expand Down Expand Up @@ -144,7 +170,17 @@ def test_from_dict(self, _mock_weaviate):
)

assert document_store._url == "http://localhost:8080"
assert document_store._collection_name == "my_collection"
assert document_store._collection_settings == {
"class": "Default",
"properties": [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
{"name": "score", "dataType": ["number"]},
],
}
assert document_store._auth_client_secret == AuthApiKey("my_api_key")
assert document_store._timeout_config == (10, 60)
assert document_store._proxies == {"http": "http://proxy:1234"}
Expand Down